Browse Source

Merge branch 'bootsteps_refactor'

Ask Solem 12 years ago
parent
commit
f70c4aacc7

+ 3 - 1
celery/app/base.py

@@ -11,7 +11,7 @@ from __future__ import absolute_import
 import threading
 import warnings
 
-from collections import deque
+from collections import defaultdict, deque
 from contextlib import contextmanager
 from copy import deepcopy
 from functools import wraps
@@ -72,6 +72,8 @@ class Celery(object):
         self.set_as_current = set_as_current
         self.registry_cls = symbol_by_name(self.registry_cls)
         self.accept_magic_kwargs = accept_magic_kwargs
+        self.user_options = defaultdict(set)
+        self.steps = defaultdict(set)
 
         self.configured = False
         self._pending_defaults = deque()

+ 6 - 5
celery/app/defaults.py

@@ -150,22 +150,23 @@ NAMESPACES = {
         'WORKER_DIRECT': Option(False, type='bool'),
     },
     'CELERYD': {
-        'AUTOSCALER': Option('celery.worker.autoscale.Autoscaler'),
-        'AUTORELOADER': Option('celery.worker.autoreload.Autoreloader'),
+        'AUTOSCALER': Option('celery.worker.autoscale:Autoscaler'),
+        'AUTORELOADER': Option('celery.worker.autoreload:Autoreloader'),
         'BOOT_STEPS': Option((), type='tuple'),
+        'CONSUMER_BOOT_STEPS': Option((), type='tuple'),
         'CONCURRENCY': Option(0, type='int'),
         'TIMER': Option(type='string'),
         'TIMER_PRECISION': Option(1.0, type='float'),
         'FORCE_EXECV': Option(True, type='bool'),
         'HIJACK_ROOT_LOGGER': Option(True, type='bool'),
-        'CONSUMER': Option(type='string'),
+        'CONSUMER': Option('celery.worker.consumer:Consumer', type='string'),
         'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT),
         'LOG_COLOR': Option(type='bool'),
         'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0',
                             alt='--loglevel argument'),
         'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
                             alt='--logfile argument'),
-        'MEDIATOR': Option('celery.worker.mediator.Mediator'),
+        'MEDIATOR': Option('celery.worker.mediator:Mediator'),
         'MAX_TASKS_PER_CHILD': Option(type='int'),
         'POOL': Option(DEFAULT_POOL),
         'POOL_PUTLOCKS': Option(True, type='bool'),
@@ -179,7 +180,7 @@ NAMESPACES = {
     },
     'CELERYBEAT': {
         'SCHEDULE': Option({}, type='dict'),
-        'SCHEDULER': Option('celery.beat.PersistentScheduler'),
+        'SCHEDULER': Option('celery.beat:PersistentScheduler'),
         'SCHEDULE_FILENAME': Option('celerybeat-schedule'),
         'MAX_LOOP_INTERVAL': Option(0, type='float'),
         'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0',

+ 8 - 3
celery/apps/worker.py

@@ -101,6 +101,10 @@ class Worker(WorkController):
             enabled=not no_color if no_color is not None else no_color
         )
 
+    def on_init_namespace(self):
+        print('SETUP LOGGING: %r' % (self.redirect_stdouts, ))
+        self.setup_logging()
+
     def on_start(self):
         WorkController.on_start(self)
 
@@ -122,10 +126,11 @@ class Worker(WorkController):
 
         # Dump configuration to screen so we have some basic information
         # for when users sends bug reports.
-        print(str(self.colored.cyan(' \n', self.startup_info())) +
-              str(self.colored.reset(self.extra_info() or '')))
+        sys.__stdout__.write(
+            str(self.colored.cyan(' \n', self.startup_info())) +
+            str(self.colored.reset(self.extra_info() or '')) + '\n'
+        )
         self.set_process_status('-active-')
-        self.setup_logging()
         self.install_platform_tweaks(self)
 
     def on_consumer_ready(self, consumer):

+ 5 - 0
celery/bin/__init__.py

@@ -0,0 +1,5 @@
+from __future__ import absolute_import
+
+from collections import defaultdict
+
+from .base import Option  # noqa

+ 33 - 1
celery/bin/celery.py

@@ -10,6 +10,7 @@ from __future__ import absolute_import, print_function
 
 import anyjson
 import heapq
+import os
 import sys
 import warnings
 
@@ -26,6 +27,12 @@ from celery.utils.timeutils import maybe_iso8601
 
 from celery.bin.base import Command as BaseCommand, Option
 
+try:
+    # print_statement does not work with io.StringIO
+    from io import BytesIO as PrintIO
+except ImportError:
+    from StringIO import StringIO as PrintIO  # noqa
+
 HELP = """
 ---- -- - - ---- Commands- -------------- --- ------------
 
@@ -40,13 +47,18 @@ Migrating task {state.count}/{state.strtotal}: \
 {body[task]}[{body[id]}]\
 """
 
-commands = {}
+DEBUG = os.environ.get('C_DEBUG', False)
 
+commands = {}
 command_classes = [
     ('Main', ['worker', 'events', 'beat', 'shell', 'multi', 'amqp'], 'green'),
     ('Remote Control', ['status', 'inspect', 'control'], 'blue'),
     ('Utils', ['purge', 'list', 'migrate', 'call', 'result', 'report'], None),
 ]
+if DEBUG:
+    command_classes.append(
+        ('Debug', ['worker_graph', 'consumer_graph'], 'red'),
+    )
 
 
 @memoize()
@@ -458,6 +470,26 @@ class result(Command):
         self.out(self.prettify(value)[1])
 
 
+@command
+class worker_graph(Command):
+
+    def run(self, **kwargs):
+        worker = self.app.WorkController()
+        out = PrintIO()
+        worker.namespace.graph.to_dot(out)
+        self.out(out.getvalue())
+
+
+@command
+class consumer_graph(Command):
+
+    def run(self, **kwargs):
+        worker = self.app.WorkController()
+        out = PrintIO()
+        worker.consumer.namespace.graph.to_dot(out)
+        self.out(out.getvalue())
+
+
 class _RemoteControl(Command):
     name = None
     choices = None

+ 1 - 1
celery/bin/celeryd.py

@@ -197,7 +197,7 @@ class WorkerCommand(Command):
             Option('--autoreload', action='store_true'),
             Option('--no-execv', action='store_true', default=False),
             Option('-D', '--detach', action='store_true'),
-        ) + daemon_options()
+        ) + daemon_options() + tuple(self.app.user_options['worker'])
 
 
 def main():

+ 330 - 0
celery/bootsteps.py

@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+"""
+    celery.bootsteps
+    ~~~~~~~~~~~~~~~~
+
+    The boot-steps!
+
+"""
+from __future__ import absolute_import
+
+from collections import deque
+from importlib import import_module
+from threading import Event
+
+from kombu.common import ignore_errors
+from kombu.utils import symbol_by_name
+
+from .datastructures import DependencyGraph
+from .utils.imports import instantiate, qualname, symbol_by_name
+from .utils.log import get_logger
+from .utils.threads import default_socket_timeout
+
+try:
+    from greenlet import GreenletExit
+    IGNORE_ERRORS = (GreenletExit, )
+except ImportError:  # pragma: no cover
+    IGNORE_ERRORS = ()
+
+#: Default socket timeout at shutdown.
+SHUTDOWN_SOCKET_TIMEOUT = 5.0
+
+#: States
+RUN = 0x1
+CLOSE = 0x2
+TERMINATE = 0x3
+
+logger = get_logger(__name__)
+debug = logger.debug
+
+
+def _pre(ns, fmt):
+    return '| {0}: {1}'.format(ns.alias, fmt)
+
+
+def _maybe_name(s):
+    if not isinstance(s, basestring):
+        return s.name
+    return s
+
+
+class Namespace(object):
+    """A namespace containing bootsteps.
+
+    :keyword steps: List of steps.
+    :keyword name: Set explicit name for this namespace.
+    :keyword app: Set the Celery app for this namespace.
+    :keyword on_start: Optional callback applied after namespace start.
+    :keyword on_close: Optional callback applied before namespace close.
+    :keyword on_stopped: Optional callback applied after namespace stopped.
+
+    """
+    name = None
+    state = None
+    started = 0
+    default_steps = set()
+
+    def __init__(self, steps=None, name=None, app=None, on_start=None,
+            on_close=None, on_stopped=None):
+        self.app = app
+        self.name = name or self.name or qualname(type(self))
+        self.types = set(steps or []) | set(self.default_steps)
+        self.on_start = on_start
+        self.on_close = on_close
+        self.on_stopped = on_stopped
+        self.shutdown_complete = Event()
+        self.steps = {}
+
+    def start(self, parent):
+        self.state = RUN
+        if self.on_start:
+            self.on_start()
+        for i, step in enumerate(filter(None, parent.steps)):
+            self._debug('Starting %s', step.alias)
+            self.started = i + 1
+            step.start(parent)
+            debug('^-- substep ok')
+
+    def close(self, parent):
+        if self.on_close:
+            self.on_close()
+        for step in parent.steps:
+            close = getattr(step, 'close', None)
+            if close:
+                close(parent)
+
+    def restart(self, parent, description='Restarting', attr='stop'):
+        with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
+            for step in reversed(parent.steps):
+                if step:
+                    self._debug('%s %s...', description, step.alias)
+                    fun = getattr(step, attr, None)
+                    if fun:
+                        fun(parent)
+
+    def stop(self, parent, close=True, terminate=False):
+        what = 'Terminating' if terminate else 'Stopping'
+        if self.state in (CLOSE, TERMINATE):
+            return
+
+        self.close(parent)
+
+        if self.state != RUN or self.started != len(parent.steps):
+            # Not fully started, can safely exit.
+            self.state = TERMINATE
+            self.shutdown_complete.set()
+            return
+        self.state = CLOSE
+        self.restart(parent, what, 'terminate' if terminate else 'stop')
+
+        if self.on_stopped:
+            self.on_stopped()
+        self.state = TERMINATE
+        self.shutdown_complete.set()
+
+    def join(self, timeout=None):
+        try:
+            # Will only get here if running green,
+            # makes sure all greenthreads have exited.
+            self.shutdown_complete.wait(timeout=timeout)
+        except IGNORE_ERRORS:
+            pass
+
+    def apply(self, parent, **kwargs):
+        """Apply the steps in this namespace to an object.
+
+        This will apply the ``__init__`` and ``include`` methods
+        of each steps with the object as argument.
+
+        For :class:`StartStopStep` the services created
+        will also be added the the objects ``steps`` attribute.
+
+        """
+        self._debug('Loading boot-steps.')
+        order = self.order = []
+        steps = self.steps = self.claim_steps()
+
+        self._debug('Building graph...')
+        for name in self._finalize_boot_steps(steps):
+            step = steps[name] = steps[name](parent, **kwargs)
+            order.append(step)
+            step.include(parent)
+        self._debug('New boot order: {%s}',
+                    ', '.join(s.alias for s in self.order))
+        return self
+
+    def import_module(self, module):
+        return import_module(module)
+
+    def __getitem__(self, name):
+        return self.steps[name]
+
+    def _find_last(self):
+        for C in self.steps.itervalues():
+            if C.last:
+                return C
+
+    def _firstpass(self, steps):
+        stream = deque(step.requires for step in steps.itervalues())
+        while stream:
+            for node in stream.popleft():
+                node = symbol_by_name(node)
+                if node.name not in self.steps:
+                    steps[node.name] = node
+                stream.append(node.requires)
+        for node in steps.itervalues():
+            node.requires = [_maybe_name(n) for n in node.requires]
+        for step in steps.values():
+            [steps[n] for n in step.requires]
+
+    def _finalize_boot_steps(self, steps):
+        self._firstpass(steps)
+        G = self.graph = DependencyGraph((C.name, C.requires)
+                            for C in steps.itervalues())
+        last = self._find_last()
+        if last:
+            for obj in G:
+                if obj != last.name:
+                    G.add_edge(last.name, obj)
+        try:
+            return G.topsort()
+        except KeyError as exc:
+            raise KeyError('unknown boot-step: %s' % exc)
+
+    def claim_steps(self):
+        return dict(self.load_step(step) for step in self._all_steps())
+
+    def _all_steps(self):
+        return self.types | self.app.steps[self.name]
+
+    def load_step(self, step):
+        step = symbol_by_name(step)
+        return step.name, step
+
+    def _debug(self, msg, *args):
+        return debug(_pre(self, msg), *args)
+
+    @property
+    def alias(self):
+        return self.name.rsplit('.', 1)[-1]
+
+
+class StepType(type):
+    """Metaclass for steps."""
+
+    def __new__(cls, name, bases, attrs):
+        module = attrs.get('__module__')
+        qname = '{0}.{1}'.format(module, name) if module else name
+        attrs.update(
+            __qualname__=qname,
+            name=attrs.get('name') or qname,
+            requires=attrs.get('requires', ()),
+        )
+        return super(StepType, cls).__new__(cls, name, bases, attrs)
+
+    def __repr__(self):
+        return 'step:{0.name}{{{0.requires!r}}}'.format(self)
+
+
+class Step(object):
+    """A Bootstep.
+
+    The :meth:`__init__` method is called when the step
+    is bound to a parent object, and can as such be used
+    to initialize attributes in the parent object at
+    parent instantiation-time.
+
+    """
+    __metaclass__ = StepType
+
+    #: Optional step name, will use qualname if not specified.
+    name = None
+
+    #: List of other steps that that must be started before this step.
+    #: Note that all dependencies must be in the same namespace.
+    requires = ()
+
+    #: Optional obj created by the :meth:`create` method.
+    #: This is used by :class:`StartStopStep` to keep the
+    #: original service object.
+    obj = None
+
+    #: This flag is reserved for the workers Consumer,
+    #: since it is required to always be started last.
+    #: There can only be one object marked with lsat
+    #: in every namespace.
+    last = False
+
+    #: This provides the default for :meth:`include_if`.
+    enabled = True
+
+    def __init__(self, parent, **kwargs):
+        pass
+
+    def create(self, parent):
+        """Create the step."""
+        pass
+
+    def include_if(self, parent):
+        """An optional predicate that decided whether this
+        step should be created."""
+        return self.enabled
+
+    def instantiate(self, name, *args, **kwargs):
+        return instantiate(name, *args, **kwargs)
+
+    def include(self, parent):
+        if self.include_if(parent):
+            self.obj = self.create(parent)
+            return True
+
+    def __repr__(self):
+        return '<step: {0.alias}>'.format(self)
+
+    @property
+    def alias(self):
+        return self.name.rsplit('.', 1)[-1]
+
+
+class StartStopStep(Step):
+
+    def start(self, parent):
+        if self.obj:
+            return self.obj.start()
+
+    def stop(self, parent):
+        if self.obj:
+            return self.obj.stop()
+
+    def close(self, parent):
+        pass
+
+    def terminate(self, parent):
+        self.stop(parent)
+
+    def include(self, parent):
+        if super(StartStopStep, self).include(parent):
+            parent.steps.append(self)
+
+
+class ConsumerStep(StartStopStep):
+    requires = ('Connection', )
+    consumers = None
+
+    def get_consumers(self, channel):
+        raise NotImplementedError('missing get_consumers')
+
+    def start(self, c):
+        self.consumers = self.get_consumers(c.connection)
+        for consumer in self.consumers or []:
+            consumer.consume()
+
+    def stop(self, c):
+        for consumer in self.consumers or []:
+            ignore_errors(c.connection, consumer.cancel)
+
+    def shutdown(self, c):
+        self.stop(c)
+        for consumer in self.consumers or []:
+            if consumer.channel:
+                ignore_errors(c.connection, consumer.channel.close)

+ 3 - 8
celery/tests/bin/test_celeryd.py

@@ -60,10 +60,7 @@ def disable_stdouts(fun):
 
 
 class Worker(cd.Worker):
-
-    def __init__(self, *args, **kwargs):
-        super(Worker, self).__init__(*args, **kwargs)
-        self.redirect_stdouts = False
+    redirect_stdouts = False
 
     def start(self, *args, **kwargs):
         self.on_start()
@@ -292,9 +289,7 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_redirect_stdouts(self):
-        worker = self.Worker()
-        worker.redirect_stdouts = False
-        worker.setup_logging()
+        self.Worker(redirect_stdouts=False)
         with self.assertRaises(AttributeError):
             sys.stdout.logger
 
@@ -306,7 +301,7 @@ class test_Worker(WorkerAppCase):
             logging_setup[0] = True
 
         try:
-            worker = self.Worker()
+            worker = self.Worker(redirect_stdouts=False)
             worker.app.log.__class__._setup = False
             worker.setup_logging()
             self.assertTrue(logging_setup[0])

+ 43 - 84
celery/tests/worker/test_bootsteps.py

@@ -2,37 +2,26 @@ from __future__ import absolute_import
 
 from mock import Mock
 
-from celery.worker import bootsteps
+from celery import bootsteps
 
 from celery.tests.utils import AppCase, Case
 
 
-class test_Component(Case):
+class test_Step(Case):
 
-    class Def(bootsteps.Component):
-        name = 'test_Component.Def'
-
-    def test_components_must_be_named(self):
-        with self.assertRaises(NotImplementedError):
-
-            class X(bootsteps.Component):
-                pass
-
-        class Y(bootsteps.Component):
-            abstract = True
+    class Def(bootsteps.Step):
+        name = 'test_Step.Def'
 
     def test_namespace_name(self, ns='test_namespace_name'):
 
-        class X(bootsteps.Component):
+        class X(bootsteps.Step):
             namespace = ns
             name = 'X'
-        self.assertEqual(X.namespace, ns)
         self.assertEqual(X.name, 'X')
 
-        class Y(bootsteps.Component):
-            name = '%s.Y' % (ns, )
-        self.assertEqual(Y.namespace, ns)
-        self.assertEqual(Y.name, 'Y')
+        class Y(bootsteps.Step):
+            name = '%s.Y' % ns
+        self.assertEqual(Y.name, '%s.Y' % ns)
 
     def test_init(self):
         self.assertTrue(self.Def(self))
@@ -70,13 +59,13 @@ class test_Component(Case):
         self.assertFalse(x.create.call_count)
 
 
-class test_StartStopComponent(Case):
+class test_StartStopStep(Case):
 
-    class Def(bootsteps.StartStopComponent):
-        name = 'test_StartStopComponent.Def'
+    class Def(bootsteps.StartStopStep):
+        name = 'test_StartStopStep.Def'
 
     def setUp(self):
-        self.components = []
+        self.steps = []
 
     def test_start__stop(self):
         x = self.Def(self)
@@ -84,10 +73,10 @@ class test_StartStopComponent(Case):
 
         # include creates the underlying object and sets
         # its x.obj attribute to it, as well as appending
-        # it to the parent.components list.
+        # it to the parent.steps list.
         x.include(self)
-        self.assertTrue(self.components)
-        self.assertIs(self.components[0], x)
+        self.assertTrue(self.steps)
+        self.assertIs(self.steps[0], x)
 
         x.start(self)
         x.obj.start.assert_called_with()
@@ -99,7 +88,7 @@ class test_StartStopComponent(Case):
         x = self.Def(self)
         x.enabled = False
         x.include(self)
-        self.assertFalse(self.components)
+        self.assertFalse(self.steps)
 
     def test_terminate(self):
         x = self.Def(self)
@@ -116,47 +105,29 @@ class test_Namespace(AppCase):
     class NS(bootsteps.Namespace):
         name = 'test_Namespace'
 
-    class ImportingNS(bootsteps.Namespace):
-
-        def __init__(self, *args, **kwargs):
-            bootsteps.Namespace.__init__(self, *args, **kwargs)
-            self.imported = []
-
-        def modules(self):
-            return ['A', 'B', 'C']
+    def test_steps_added_to_unclaimed(self):
 
-        def import_module(self, module):
-            self.imported.append(module)
-
-    def test_components_added_to_unclaimed(self):
-
-        class tnA(bootsteps.Component):
+        class tnA(bootsteps.Step):
             name = 'test_Namespace.A'
 
-        class tnB(bootsteps.Component):
+        class tnB(bootsteps.Step):
             name = 'test_Namespace.B'
 
-        class xxA(bootsteps.Component):
+        class xxA(bootsteps.Step):
             name = 'xx.A'
 
-        self.assertIn('A', self.NS._unclaimed['test_Namespace'])
-        self.assertIn('B', self.NS._unclaimed['test_Namespace'])
-        self.assertIn('A', self.NS._unclaimed['xx'])
-        self.assertNotIn('B', self.NS._unclaimed['xx'])
+        class NS(self.NS):
+            default_steps = [tnA, tnB]
+        ns = NS(app=self.app)
+
+        self.assertIn(tnA, ns._all_steps())
+        self.assertIn(tnB, ns._all_steps())
+        self.assertNotIn(xxA, ns._all_steps())
 
     def test_init(self):
         ns = self.NS(app=self.app)
         self.assertIs(ns.app, self.app)
         self.assertEqual(ns.name, 'test_Namespace')
-        self.assertFalse(ns.services)
-
-    def test_interface_modules(self):
-        self.NS(app=self.app).modules()
-
-    def test_load_modules(self):
-        x = self.ImportingNS(app=self.app)
-        x.load_modules()
-        self.assertListEqual(x.imported, ['A', 'B', 'C'])
 
     def test_apply(self):
 
@@ -166,44 +137,32 @@ class test_Namespace(AppCase):
             def modules(self):
                 return ['A', 'B']
 
-        class A(bootsteps.Component):
-            name = 'test_apply.A'
-            requires = ['C']
-
-        class B(bootsteps.Component):
+        class B(bootsteps.Step):
             name = 'test_apply.B'
 
-        class C(bootsteps.Component):
+        class C(bootsteps.Step):
             name = 'test_apply.C'
-            requires = ['B']
+            requires = [B]
 
-        class D(bootsteps.Component):
+        class A(bootsteps.Step):
+            name = 'test_apply.A'
+            requires = [C]
+
+        class D(bootsteps.Step):
             name = 'test_apply.D'
             last = True
 
-        x = MyNS(app=self.app)
-        x.import_module = Mock()
+        x = MyNS([A, D], app=self.app)
         x.apply(self)
 
-        self.assertItemsEqual(x.components.values(), [A, B, C, D])
-        self.assertTrue(x.import_module.call_count)
-
-        for boot_step in x.boot_steps:
-            self.assertEqual(boot_step.namespace, x)
-
-        self.assertIsInstance(x.boot_steps[0], B)
-        self.assertIsInstance(x.boot_steps[1], C)
-        self.assertIsInstance(x.boot_steps[2], A)
-        self.assertIsInstance(x.boot_steps[3], D)
-
-        self.assertIs(x['A'], A)
-
-    def test_import_module(self):
-        x = self.NS(app=self.app)
-        import os
-        self.assertIs(x.import_module('os'), os)
+        self.assertIsInstance(x.order[0], B)
+        self.assertIsInstance(x.order[1], C)
+        self.assertIsInstance(x.order[2], A)
+        self.assertIsInstance(x.order[3], D)
+        self.assertIn(A, x.types)
+        self.assertIs(x[A.name], x.order[2])
 
-    def test_find_last_but_no_components(self):
+    def test_find_last_but_no_steps(self):
 
         class MyNS(bootsteps.Namespace):
             name = 'qwejwioqjewoqiej'

+ 207 - 144
celery/tests/worker/test_worker.py

@@ -9,7 +9,7 @@ from Queue import Empty
 
 from billiard.exceptions import WorkerLostError
 from kombu import Connection
-from kombu.common import QoS, PREFETCH_COUNT_MAX
+from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from mock import Mock, patch
@@ -17,6 +17,7 @@ from nose import SkipTest
 
 from celery import current_app
 from celery.app.defaults import DEFAULTS
+from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.exceptions import SystemTerminate
@@ -24,33 +25,51 @@ from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import WorkController
-from celery.worker.components import Queues, Timers, EvLoop, Pool
-from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
+from celery.worker.components import Queues, Timers, Hub, Pool
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
-from celery.worker.consumer import BlockingConsumer
+from celery.worker import consumer
+from celery.worker.consumer import Consumer
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
 from celery.tests.utils import AppCase, Case
 
 
+def MockStep(step=None):
+    step = Mock() if step is None else step
+    step.namespace = Mock()
+    step.namespace.name = 'MockNS'
+    step.name = 'MockStep'
+    return step
+
+
 class PlaceHolder(object):
         pass
 
 
-class MyKombuConsumer(BlockingConsumer):
+def find_step(obj, typ):
+    return obj.namespace.steps[typ.name]
+
+
+class _MyKombuConsumer(Consumer):
     broadcast_consumer = Mock()
     task_consumer = Mock()
 
     def __init__(self, *args, **kwargs):
         kwargs.setdefault('pool', BasePool(2))
-        super(MyKombuConsumer, self).__init__(*args, **kwargs)
+        super(_MyKombuConsumer, self).__init__(*args, **kwargs)
 
     def restart_heartbeat(self):
         self.heart = None
 
 
+class MyKombuConsumer(Consumer):
+
+    def loop(self, *args, **kwargs):
+        pass
+
+
 class MockNode(object):
     commands = []
 
@@ -227,90 +246,102 @@ class test_Consumer(Case):
 
     def test_start_when_closed(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.start()
 
     def test_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
 
-        l._state = RUN
+        l.namespace.state = RUN
         l.event_dispatcher = None
-        l.stop_consumers(close_connection=False)
+        l.namespace.restart(l)
         self.assertTrue(l.connection)
 
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
-        l.stop_consumers()
+        l.namespace.restart(l)
 
         l.stop()
-        l.close_connection()
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
 
     def test_close_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
-        l.close_connection()
+        l.namespace.state = RUN
+        step = find_step(l, consumer.Connection)
+        conn = l.connection = Mock()
+        step.shutdown(l)
+        self.assertTrue(conn.close.called)
+        self.assertIsNone(l.connection)
 
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         eventer = l.event_dispatcher = Mock()
         eventer.enabled = True
         heart = l.heart = MockHeart()
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        Events = find_step(l, consumer.Events)
+        Events.shutdown(l)
+        Heart = find_step(l, consumer.Heart)
+        Heart.shutdown(l)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
 
     @patch('celery.worker.consumer.warn')
     def test_receive_message_unknown(self, warn):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
+        l.node = MockNode()
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(warn.call_count)
 
-    @patch('celery.utils.timer2.to_timestamp')
+    @patch('celery.worker.consumer.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                                    args=('2, 2'),
                                    kwargs={},
                                    eta=datetime.now().isoformat())
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
+        l.node = MockNode()
         l.update_strategies()
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(m.acknowledged)
         self.assertTrue(to_timestamp.call_count)
 
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
     @patch('celery.worker.consumer.crit')
     def test_on_decode_error(self, crit):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
 
         class MockMessage(Mock):
             content_type = 'application/x-msgpack'
@@ -322,14 +353,25 @@ class test_Consumer(Case):
         self.assertTrue(message.ack.call_count)
         self.assertIn("Can't decode message body", crit.call_args[0][0])
 
+    def _get_on_message(self, l):
+        l.qos = Mock()
+        l.event_dispatcher = Mock()
+        l.task_consumer = Mock()
+        l.connection = Mock()
+        l.connection.drain_events.side_effect = SystemExit()
+
+        with self.assertRaises(SystemExit):
+            l.loop(*l.loop_args())
+        self.assertTrue(l.task_consumer.register_callback.called)
+        return l.task_consumer.register_callback.call_args[0][0]
+
     def test_receieve_message(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
-
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
 
         in_bucket = self.ready_queue.get_nowait()
         self.assertIsInstance(in_bucket, Request)
@@ -339,10 +381,10 @@ class test_Consumer(Case):
 
     def test_start_connection_error(self):
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                     self.iterations = 1
                     raise KeyError('foo')
@@ -360,10 +402,10 @@ class test_Consumer(Case):
         # Regression test for AMQPChannelExceptions that can occur within the
         # consumer. (i.e. 404 errors)
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                     self.iterations = 1
                     raise KeyError('foo')
@@ -377,7 +419,7 @@ class test_Consumer(Case):
         l.heart.stop()
         l.timer.stop()
 
-    def test_consume_messages_ignores_socket_timeout(self):
+    def test_loop_ignores_socket_timeout(self):
 
         class Connection(current_app.connection().__class__):
             obj = None
@@ -391,9 +433,9 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.connection.obj = l
         l.qos = QoS(l.task_consumer, 10)
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
-    def test_consume_messages_when_socket_error(self):
+    def test_loop_when_socket_error(self):
 
         class Connection(current_app.connection().__class__):
             obj = None
@@ -402,20 +444,20 @@ class test_Consumer(Case):
                 self.obj.connection = None
                 raise socket.error('foo')
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
+        l = Consumer(self.ready_queue, timer=self.timer)
+        l.namespace.state = RUN
         c = l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
         with self.assertRaises(socket.error):
-            l.consume_messages()
+            l.loop(*l.loop_args())
 
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.connection = c
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
-    def test_consume_messages(self):
+    def test_loop(self):
 
         class Connection(current_app.connection().__class__):
             obj = None
@@ -423,14 +465,14 @@ class test_Consumer(Case):
             def drain_events(self, **kwargs):
                 self.obj.connection = None
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
 
-        l.consume_messages()
-        l.consume_messages()
+        l.loop(*l.loop_args())
+        l.loop(*l.loop_args())
         self.assertTrue(l.task_consumer.consume.call_count)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
         l.task_consumer.qos = Mock()
@@ -441,15 +483,15 @@ class test_Consumer(Case):
         self.assertEqual(l.qos.value, 9)
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
-    def test_maybe_conn_error(self):
+    def test_ignore_errors(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l.connection_errors = (KeyError, )
         l.channel_errors = (SyntaxError, )
-        l.maybe_conn_error(Mock(side_effect=AttributeError('foo')))
-        l.maybe_conn_error(Mock(side_effect=KeyError('foo')))
-        l.maybe_conn_error(Mock(side_effect=SyntaxError('foo')))
+        ignore_errors(l, Mock(side_effect=AttributeError('foo')))
+        ignore_errors(l, Mock(side_effect=KeyError('foo')))
+        ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
         with self.assertRaises(IndexError):
-            l.maybe_conn_error(Mock(side_effect=IndexError('foo')))
+            ignore_errors(l, Mock(side_effect=IndexError('foo')))
 
     def test_apply_eta_task(self):
         from celery.worker import state
@@ -464,18 +506,20 @@ class test_Consumer(Case):
         self.assertIs(self.ready_queue.get_nowait(), task)
 
     def test_receieve_message_eta_isoformat(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                            eta=datetime.now().isoformat(),
                            args=[2, 4, 8], kwargs={})
 
         l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
+        l.qos = QoS(l.task_consumer, 1)
         current_pcount = l.qos.value
         l.event_dispatcher = Mock()
         l.enabled = False
         l.update_strategies()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         l.timer.join(1)
 
@@ -488,28 +532,30 @@ class test_Consumer(Case):
         self.assertGreater(l.qos.value, current_pcount)
         l.timer.stop()
 
-    def test_on_control(self):
+    def test_pidbox_callback(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        l.reset_pidbox_node = Mock()
+        con = find_step(l, consumer.Control).box
+        con.node = Mock()
+        con.reset = Mock()
 
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = KeyError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.node = Mock()
+        con.node.handle_message.side_effect = KeyError('foo')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = ValueError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
-        l.reset_pidbox_node.assert_called_with()
+        con.node = Mock()
+        con.node.handle_message.side_effect = ValueError('foo')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
+        self.assertTrue(con.reset.called)
 
     def test_revoke(self):
         ready_queue = FastQueue()
-        l = MyKombuConsumer(ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         id = uuid()
         t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
@@ -517,16 +563,19 @@ class test_Consumer(Case):
         from celery.worker.state import revoked
         revoked.add(id)
 
-        l.receive_message(t.decode(), t)
+        callback = self._get_on_message(l)
+        callback(t.decode(), t)
         self.assertTrue(ready_queue.empty())
 
     def test_receieve_message_not_registered(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
 
         l.event_dispatcher = Mock()
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
         self.assertTrue(self.timer.empty())
@@ -534,7 +583,7 @@ class test_Consumer(Case):
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
@@ -542,7 +591,8 @@ class test_Consumer(Case):
         l.connection_errors = (socket.error, )
         m.reject = Mock()
         m.reject.side_effect = socket.error('foo')
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         self.assertTrue(warn.call_count)
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
@@ -551,7 +601,8 @@ class test_Consumer(Case):
         self.assertTrue(logger.critical.call_count)
 
     def test_receieve_message_eta(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         l.event_dispatcher = Mock()
         l.event_dispatcher._outbound_buffer = deque()
         backend = Mock()
@@ -560,16 +611,17 @@ class test_Consumer(Case):
                            eta=(datetime.now() +
                                timedelta(days=1)).isoformat())
 
-        l.reset_connection()
+        l.namespace.start(l)
         p = l.app.conf.BROKER_CONNECTION_RETRY
         l.app.conf.BROKER_CONNECTION_RETRY = False
         try:
-            l.reset_connection()
+            l.namespace.start(l)
         finally:
             l.app.conf.BROKER_CONNECTION_RETRY = p
-        l.stop_consumers()
+        l.namespace.restart(l)
         l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         in_hold = l.timer.queue[0]
         self.assertEqual(len(in_hold), 3)
@@ -583,24 +635,34 @@ class test_Consumer(Case):
 
     def test_reset_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        chan = l.pidbox_node.channel = Mock()
+        con = find_step(l, consumer.Control).box
+        con.node = Mock()
+        chan = con.node.channel = Mock()
         l.connection = Mock()
         chan.close.side_effect = socket.error('foo')
         l.connection_errors = (socket.error, )
-        l.reset_pidbox_node()
+        con.reset()
         chan.close.assert_called_with()
 
     def test_reset_pidbox_node_green(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pool = Mock()
-        l.pool.is_green = True
-        l.reset_pidbox_node()
-        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+        from celery.worker.pidbox import gPidbox
+        pool = Mock()
+        pool.is_green = True
+        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        con = find_step(l, consumer.Control)
+        self.assertIsInstance(con.box, gPidbox)
+        con.start(l)
+        l.pool.spawn_n.assert_called_with(
+            con.box.loop, l,
+        )
 
     def test__green_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
+        pool = Mock()
+        pool.is_green = True
+        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        l.node = Mock()
+        controller = find_step(l, consumer.Control)
+        box = controller.box
 
         class BConsumer(Mock):
 
@@ -611,7 +673,7 @@ class test_Consumer(Case):
             def __exit__(self, *exc_info):
                 self.cancel()
 
-        l.pidbox_node.listen = BConsumer()
+        controller.box.node.listen = BConsumer()
         connections = []
 
         class Connection(object):
@@ -640,25 +702,26 @@ class test_Consumer(Case):
                     self.calls += 1
                     raise socket.timeout()
                 self.obj.connection = None
-                self.obj._pidbox_node_shutdown.set()
+                controller.box._node_shutdown.set()
 
             def close(self):
                 self.closed = True
 
         l.connection = Mock()
-        l._open_connection = lambda: Connection(obj=l)
-        l._green_pidbox_node()
+        l.connect = lambda: Connection(obj=l)
+        controller = find_step(l, consumer.Control)
+        controller.box.loop(l)
 
-        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
-        self.assertTrue(l.broadcast_consumer)
-        l.broadcast_consumer.consume.assert_called_with()
+        self.assertTrue(controller.box.node.listen.called)
+        self.assertTrue(controller.box.consumer)
+        controller.box.consumer.consume.assert_called_with()
 
         self.assertIsNone(l.connection)
         self.assertTrue(connections[0].closed)
 
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.utils.sleep')
-    def test_open_connection_errback(self, sleep, connect):
+    def test_connect_errback(self, sleep, connect):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
@@ -668,17 +731,18 @@ class test_Consumer(Case):
                 return
             raise StdChannelError()
         connect.side_effect = effect
-        l._open_connection()
+        l.connect()
         connect.assert_called_with()
 
     def test_stop_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._pidbox_node_stopped = Event()
-        l._pidbox_node_shutdown = Event()
-        l._pidbox_node_stopped.set()
-        l.stop_pidbox_node()
+        cont = find_step(l, consumer.Control)
+        cont._node_stopped = Event()
+        cont._node_shutdown = Event()
+        cont._node_stopped.set()
+        cont.stop(l)
 
-    def test_start__consume_messages(self):
+    def test_start__loop(self):
 
         class _QoS(object):
             prev = 3
@@ -703,18 +767,17 @@ class test_Consumer(Case):
         l.connection = Connection()
         l.iterations = 0
 
-        def raises_KeyError(limit=None):
+        def raises_KeyError(*args, **kwargs):
             l.iterations += 1
             if l.qos.prev != l.qos.value:
                 l.qos.update()
             if l.iterations >= 2:
                 raise KeyError('foo')
 
-        l.consume_messages = raises_KeyError
+        l.loop = raises_KeyError
         with self.assertRaises(KeyError):
             l.start()
-        self.assertTrue(init_callback.call_count)
-        self.assertEqual(l.iterations, 1)
+        self.assertEqual(l.iterations, 2)
         self.assertEqual(l.qos.prev, l.qos.value)
 
         init_callback.reset_mock()
@@ -724,25 +787,25 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.connection = Connection()
-        l.consume_messages = Mock(side_effect=socket.error('foo'))
+        l.loop = Mock(side_effect=socket.error('foo'))
         with self.assertRaises(socket.error):
             l.start()
-        self.assertTrue(init_callback.call_count)
-        self.assertTrue(l.consume_messages.call_count)
+        self.assertTrue(l.loop.call_count)
 
     def test_reset_connection_with_no_node(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         self.assertEqual(None, l.pool)
-        l.reset_connection()
+        l.namespace.start(l)
 
     def test_on_task_revoked(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task.revoked.return_value = True
         l.on_task(task)
 
     def test_on_task_no_events(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task.revoked.return_value = False
         l.event_dispatcher = Mock()
@@ -778,7 +841,7 @@ class test_WorkController(AppCase):
     def test_use_pidfile(self, create_pidlock):
         create_pidlock.return_value = Mock()
         worker = self.create_worker(pidfile='pidfilelockfilepid')
-        worker.components = []
+        worker.steps = []
         worker.start()
         self.assertTrue(create_pidlock.called)
         worker.stop()
@@ -825,12 +888,12 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.pool)
         self.assertTrue(worker.consumer)
         self.assertTrue(worker.mediator)
-        self.assertTrue(worker.components)
+        self.assertTrue(worker.steps)
 
     def test_with_embedded_celerybeat(self):
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         self.assertTrue(worker.beat)
-        self.assertIn(worker.beat, [w.obj for w in worker.components])
+        self.assertIn(worker.beat, [w.obj for w in worker.steps])
 
     def test_with_autoscaler(self):
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
@@ -892,7 +955,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode())
-        worker.components = []
+        worker.steps = []
         worker.namespace.state = RUN
         with self.assertRaises(KeyboardInterrupt):
             worker.process_task(task)
@@ -906,7 +969,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode())
-        worker.components = []
+        worker.steps = []
         worker.namespace.state = RUN
         with self.assertRaises(SystemExit):
             worker.process_task(task)
@@ -925,17 +988,18 @@ class test_WorkController(AppCase):
 
     def test_start_catches_base_exceptions(self):
         worker1 = self.create_worker()
-        stc = Mock()
+        stc = MockStep()
         stc.start.side_effect = SystemTerminate()
-        worker1.components = [stc]
+        worker1.steps = [stc]
         worker1.start()
+        stc.start.assert_called_with(worker1)
         self.assertTrue(stc.terminate.call_count)
 
         worker2 = self.create_worker()
-        sec = Mock()
+        sec = MockStep()
         sec.start.side_effect = SystemExit()
         sec.terminate = None
-        worker2.components = [sec]
+        worker2.steps = [sec]
         worker2.start()
         self.assertTrue(sec.stop.call_count)
 
@@ -988,18 +1052,18 @@ class test_WorkController(AppCase):
     def test_start__stop(self):
         worker = self.worker
         worker.namespace.shutdown_complete.set()
-        worker.components = [StartStopComponent(self) for _ in range(4)]
+        worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
         worker.namespace.state = RUN
         worker.namespace.started = 4
-        for w in worker.components:
+        for w in worker.steps:
             w.start = Mock()
             w.stop = Mock()
 
         worker.start()
-        for w in worker.components:
+        for w in worker.steps:
             self.assertTrue(w.start.call_count)
         worker.stop()
-        for w in worker.components:
+        for w in worker.steps:
             self.assertTrue(w.stop.call_count)
 
         # Doesn't close pool if no pool.
@@ -1008,15 +1072,15 @@ class test_WorkController(AppCase):
         worker.stop()
 
         # test that stop of None is not attempted
-        worker.components[-1] = None
+        worker.steps[-1] = None
         worker.start()
         worker.stop()
 
-    def test_component_raises(self):
+    def test_step_raises(self):
         worker = self.worker
-        comp = Mock()
-        worker.components = [comp]
-        comp.start.side_effect = TypeError()
+        step = Mock()
+        worker.steps = [step]
+        step.start.side_effect = TypeError()
         worker.stop = Mock()
         worker.start()
         worker.stop.assert_called_with()
@@ -1029,16 +1093,15 @@ class test_WorkController(AppCase):
         worker.namespace.shutdown_complete.set()
         worker.namespace.started = 5
         worker.namespace.state = RUN
-        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
-
+        worker.steps = [MockStep() for _ in range(5)]
         worker.start()
-        for w in worker.components[:3]:
+        for w in worker.steps[:3]:
             self.assertTrue(w.start.call_count)
-        self.assertTrue(worker.namespace.started, len(worker.components))
+        self.assertTrue(worker.namespace.started, len(worker.steps))
         self.assertEqual(worker.namespace.state, RUN)
         worker.terminate()
-        for component in worker.components:
-            self.assertTrue(component.terminate.call_count)
+        for step in worker.steps:
+            self.assertTrue(step.terminate.call_count)
 
     def test_Queues_pool_not_rlimit_safe(self):
         w = Mock()
@@ -1052,9 +1115,9 @@ class test_WorkController(AppCase):
         Queues(w).create(w)
         self.assertIs(w.ready_queue.put, w.process_task)
 
-    def test_EvLoop_crate(self):
+    def test_Hub_crate(self):
         w = Mock()
-        x = EvLoop(w)
+        x = Hub(w)
         hub = x.create(w)
         self.assertTrue(w.timer.max_interval)
         self.assertIs(w.hub, hub)

+ 4 - 12
celery/utils/imports.py

@@ -24,18 +24,10 @@ class NotAPackage(Exception):
     pass
 
 
-if sys.version_info >= (3, 3):  # pragma: no cover
-
-    def qualname(obj):
-        return obj.__qualname__
-
-else:
-
-    def qualname(obj):  # noqa
-        if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
-            return qualname(obj.__class__)
-
-        return '%s.%s' % (obj.__module__, obj.__name__)
+def qualname(obj):  # noqa
+    if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
+        obj = obj.__class__
+    return '%s.%s' % (obj.__module__, obj.__name__)
 
 
 def instantiate(name, *args, **kwargs):

+ 2 - 0
celery/utils/text.py

@@ -13,6 +13,8 @@ from textwrap import fill
 
 from pprint import pformat
 
+from kombu.utils.encoding import safe_repr
+
 
 def dedent_initial(s, n=4):
     return s[n:] if s[:n] == ' ' * n else s

+ 11 - 0
celery/utils/threads.py

@@ -9,10 +9,13 @@
 from __future__ import absolute_import, print_function
 
 import os
+import socket
 import sys
 import threading
 import traceback
 
+from contextlib import contextmanager
+
 from celery.local import Proxy
 from celery.utils.compat import THREAD_TIMEOUT_MAX
 
@@ -284,6 +287,14 @@ class LocalManager(object):
             self.__class__.__name__, len(self.locals))
 
 
+@contextmanager
+def default_socket_timeout(timeout):
+    prev = socket.getdefaulttimeout()
+    socket.setdefaulttimeout(timeout)
+    yield
+    socket.setdefaulttimeout(prev)
+
+
 class _FastLocalStack(threading.local):
 
     def __init__(self):

+ 35 - 26
celery/worker/__init__.py

@@ -6,7 +6,7 @@
     :class:`WorkController` can be used to instantiate in-process workers.
 
     The worker consists of several components, all managed by boot-steps
-    (mod:`celery.worker.bootsteps`).
+    (mod:`celery.bootsteps`).
 
 """
 from __future__ import absolute_import
@@ -19,6 +19,7 @@ from billiard import cpu_count
 from kombu.syn import detect_environment
 from kombu.utils.finalize import Finalize
 
+from celery import bootsteps
 from celery import concurrency as _concurrency
 from celery import platforms
 from celery import signals
@@ -31,7 +32,6 @@ from celery.utils import worker_direct
 from celery.utils.imports import reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 
-from . import bootsteps
 from . import state
 
 UNKNOWN_QUEUE = """\
@@ -43,24 +43,6 @@ enable the CELERY_CREATE_MISSING_QUEUES setting.
 """
 
 
-class Namespace(bootsteps.Namespace):
-    """This is the boot-step namespace of the :class:`WorkController`.
-
-    It loads modules from :setting:`CELERYD_BOOT_STEPS`, and its
-    own set of built-in boot-step modules.
-
-    """
-    name = 'worker'
-    builtin_boot_steps = ('celery.worker.components',
-                          'celery.worker.autoscale',
-                          'celery.worker.autoreload',
-                          'celery.worker.consumer',
-                          'celery.worker.mediator')
-
-    def modules(self):
-        return self.builtin_boot_steps + self.app.conf.CELERYD_BOOT_STEPS
-
-
 class WorkController(configurated):
     """Unmanaged worker instance."""
     app = None
@@ -90,6 +72,28 @@ class WorkController(configurated):
 
     pidlock = None
 
+    class Namespace(bootsteps.Namespace):
+        """This is the boot-step namespace of the :class:`WorkController`.
+
+        It loads modules from :setting:`CELERYD_BOOT_STEPS`, and its
+        own set of built-in boot-step modules.
+
+        """
+        name = 'Worker'
+        default_steps = set([
+            'celery.worker.components:Hub',
+            'celery.worker.components:Queues',
+            'celery.worker.components:Pool',
+            'celery.worker.components:Beat',
+            'celery.worker.components:Timers',
+            'celery.worker.components:StateDB',
+            'celery.worker.components:Consumer',
+            'celery.worker.autoscale:WorkerComponent',
+            'celery.worker.autoreload:WorkerComponent',
+            'celery.worker.mediator:WorkerComponent',
+
+        ])
+
     def __init__(self, app=None, hostname=None, **kwargs):
         self.app = app_or_default(app or self.app)
         self.hostname = hostname or socket.gethostname()
@@ -117,18 +121,23 @@ class WorkController(configurated):
         self.loglevel = mlevel(self.loglevel)
         self.ready_callback = ready_callback or self.on_consumer_ready
         self.use_eventloop = self.should_use_eventloop()
+        self.options = kwargs
 
         signals.worker_init.send(sender=self)
 
         # Initialize boot steps
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
-        self.components = []
-        self.namespace = Namespace(app=self.app,
-                                   on_start=self.on_start,
-                                   on_close=self.on_close,
-                                   on_stopped=self.on_stopped)
+        self.steps = []
+        self.on_init_namespace()
+        self.namespace = self.Namespace(app=self.app,
+                                        on_start=self.on_start,
+                                        on_close=self.on_close,
+                                        on_stopped=self.on_stopped)
         self.namespace.apply(self, **kwargs)
 
+    def on_init_namespace(self):
+        pass
+
     def on_before_init(self, **kwargs):
         pass
 
@@ -144,7 +153,7 @@ class WorkController(configurated):
 
     def on_stopped(self):
         self.timer.stop()
-        self.consumer.close_connection()
+        self.consumer.shutdown()
 
         if self.pidlock:
             self.pidlock.release()

+ 4 - 4
celery/worker/autoreload.py

@@ -18,12 +18,13 @@ from threading import Event
 
 from kombu.utils import eventio
 
+from celery import bootsteps
 from celery.platforms import ignore_errno
 from celery.utils.imports import module_file
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
 
-from .bootsteps import StartStopComponent
+from .components import Pool
 
 try:                        # pragma: no cover
     import pyinotify
@@ -35,9 +36,8 @@ except ImportError:         # pragma: no cover
 logger = get_logger(__name__)
 
 
-class WorkerComponent(StartStopComponent):
-    name = 'worker.autoreloader'
-    requires = ('pool', )
+class WorkerComponent(bootsteps.StartStopStep):
+    requires = (Pool, )
 
     def __init__(self, w, autoreload=None, **kwargs):
         self.enabled = w.autoreload = autoreload

+ 4 - 4
celery/worker/autoscale.py

@@ -18,20 +18,20 @@ import threading
 from functools import partial
 from time import sleep, time
 
+from celery import bootsteps
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
 
 from . import state
-from .bootsteps import StartStopComponent
+from .components import Pool
 from .hub import DummyLock
 
 logger = get_logger(__name__)
 debug, info, error = logger.debug, logger.info, logger.error
 
 
-class WorkerComponent(StartStopComponent):
-    name = 'worker.autoscaler'
-    requires = ('pool', )
+class WorkerComponent(bootsteps.StartStopStep):
+    requires = (Pool, )
 
     def __init__(self, w, **kwargs):
         self.enabled = w.autoscale

+ 0 - 292
celery/worker/bootsteps.py

@@ -1,292 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.worker.bootsteps
-    ~~~~~~~~~~~~~~~~~~~~~~~
-
-    The boot-step components.
-
-"""
-from __future__ import absolute_import
-
-import socket
-
-from collections import defaultdict
-from importlib import import_module
-from threading import Event
-
-from celery.datastructures import DependencyGraph
-from celery.utils.imports import instantiate, qualname
-from celery.utils.log import get_logger
-
-try:
-    from greenlet import GreenletExit
-    IGNORE_ERRORS = (GreenletExit, )
-except ImportError:  # pragma: no cover
-    IGNORE_ERRORS = ()
-
-#: Default socket timeout at shutdown.
-SHUTDOWN_SOCKET_TIMEOUT = 5.0
-
-#: States
-RUN = 0x1
-CLOSE = 0x2
-TERMINATE = 0x3
-
-logger = get_logger(__name__)
-
-
-class Namespace(object):
-    """A namespace containing components.
-
-    Every component must belong to a namespace.
-
-    When component classes are created they are added to the
-    mapping of unclaimed components.  The components will be
-    claimed when the namespace they belong to is created.
-
-    :keyword name: Set the name of this namespace.
-    :keyword app: Set the Celery app for this namespace.
-
-    """
-    name = None
-    state = None
-    started = 0
-
-    _unclaimed = defaultdict(dict)
-
-    def __init__(self, name=None, app=None, on_start=None,
-            on_close=None, on_stopped=None):
-        self.app = app
-        self.name = name or self.name
-        self.on_start = on_start
-        self.on_close = on_close
-        self.on_stopped = on_stopped
-        self.services = []
-        self.shutdown_complete = Event()
-
-    def start(self, parent):
-        self.state = RUN
-        if self.on_start:
-            self.on_start()
-        for i, component in enumerate(parent.components):
-            if component:
-                logger.debug('Starting %s...', qualname(component))
-                self.started = i + 1
-                component.start(parent)
-                logger.debug('%s OK!', qualname(component))
-
-    def close(self, parent):
-        if self.on_close:
-            self.on_close()
-        for component in parent.components:
-            try:
-                close = component.close
-            except AttributeError:
-                pass
-            else:
-                close(parent)
-
-    def stop(self, parent, terminate=False):
-        what = 'Terminating' if terminate else 'Stopping'
-        socket_timeout = socket.getdefaulttimeout()
-        socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
-
-        if self.state in (CLOSE, TERMINATE):
-            return
-
-        self.close(parent)
-
-        if self.state != RUN or self.started != len(parent.components):
-            # Not fully started, can safely exit.
-            self.state = TERMINATE
-            self.shutdown_complete.set()
-            return
-        self.state = CLOSE
-
-        for component in reversed(parent.components):
-            if component:
-                logger.debug('%s %s...', what, qualname(component))
-                (component.terminate if terminate else component.stop)(parent)
-
-        if self.on_stopped:
-            self.on_stopped()
-        self.state = TERMINATE
-        socket.setdefaulttimeout(socket_timeout)
-        self.shutdown_complete.set()
-
-    def join(self, timeout=None):
-        try:
-            # Will only get here if running green,
-            # makes sure all greenthreads have exited.
-            self.shutdown_complete.wait(timeout=timeout)
-        except IGNORE_ERRORS:
-            pass
-
-    def modules(self):
-        """Subclasses can override this to return a
-        list of modules to import before components are claimed."""
-        return []
-
-    def load_modules(self):
-        """Will load the component modules this namespace depends on."""
-        for m in self.modules():
-            self.import_module(m)
-
-    def apply(self, parent, **kwargs):
-        """Apply the components in this namespace to an object.
-
-        This will apply the ``__init__`` and ``include`` methods
-        of each components with the object as argument.
-
-        For ``StartStopComponents`` the services created
-        will also be added the the objects ``components`` attribute.
-
-        """
-        self._debug('Loading modules.')
-        self.load_modules()
-        self._debug('Claiming components.')
-        self.components = self._claim()
-        self._debug('Building boot step graph.')
-        self.boot_steps = [self.bind_component(name, parent, **kwargs)
-                                for name in self._finalize_boot_steps()]
-        self._debug('New boot order: {%s}',
-                ', '.join(c.name for c in self.boot_steps))
-
-        for component in self.boot_steps:
-            component.include(parent)
-        return self
-
-    def bind_component(self, name, parent, **kwargs):
-        """Bind component to parent object and this namespace."""
-        comp = self[name](parent, **kwargs)
-        comp.namespace = self
-        return comp
-
-    def import_module(self, module):
-        return import_module(module)
-
-    def __getitem__(self, name):
-        return self.components[name]
-
-    def _find_last(self):
-        for C in self.components.itervalues():
-            if C.last:
-                return C
-
-    def _finalize_boot_steps(self):
-        G = self.graph = DependencyGraph((C.name, C.requires)
-                            for C in self.components.itervalues())
-        last = self._find_last()
-        if last:
-            for obj in G:
-                if obj != last.name:
-                    G.add_edge(last.name, obj)
-        return G.topsort()
-
-    def _claim(self):
-        return self._unclaimed[self.name]
-
-    def _debug(self, msg, *args):
-        return logger.debug('[%s] ' + msg,
-                            *(self.name.capitalize(), ) + args)
-
-
-class ComponentType(type):
-    """Metaclass for components."""
-
-    def __new__(cls, name, bases, attrs):
-        abstract = attrs.pop('abstract', False)
-        if not abstract:
-            try:
-                cname = attrs['name']
-            except KeyError:
-                raise NotImplementedError('Components must be named')
-            namespace = attrs.get('namespace', None)
-            if not namespace:
-                attrs['namespace'], _, attrs['name'] = cname.partition('.')
-        cls = super(ComponentType, cls).__new__(cls, name, bases, attrs)
-        if not abstract:
-            Namespace._unclaimed[cls.namespace][cls.name] = cls
-        return cls
-
-
-class Component(object):
-    """A component.
-
-    The :meth:`__init__` method is called when the component
-    is bound to a parent object, and can as such be used
-    to initialize attributes in the parent object at
-    parent instantiation-time.
-
-    """
-    __metaclass__ = ComponentType
-
-    #: The name of the component, or the namespace
-    #: and the name of the component separated by dot.
-    name = None
-
-    #: List of component names this component depends on.
-    #: Note that the dependencies must be in the same namespace.
-    requires = ()
-
-    #: can be used to specify the namespace,
-    #: if the name does not include it.
-    namespace = None
-
-    #: if set the component will not be registered,
-    #: but can be used as a component base class.
-    abstract = True
-
-    #: Optional obj created by the :meth:`create` method.
-    #: This is used by StartStopComponents to keep the
-    #: original service object.
-    obj = None
-
-    #: This flag is reserved for the workers Consumer,
-    #: since it is required to always be started last.
-    #: There can only be one object marked with lsat
-    #: in every namespace.
-    last = False
-
-    #: This provides the default for :meth:`include_if`.
-    enabled = True
-
-    def __init__(self, parent, **kwargs):
-        pass
-
-    def create(self, parent):
-        """Create the component."""
-        pass
-
-    def include_if(self, parent):
-        """An optional predicate that decided whether this
-        component should be created."""
-        return self.enabled
-
-    def instantiate(self, qualname, *args, **kwargs):
-        return instantiate(qualname, *args, **kwargs)
-
-    def include(self, parent):
-        if self.include_if(parent):
-            self.obj = self.create(parent)
-            return True
-
-
-class StartStopComponent(Component):
-    abstract = True
-
-    def start(self, parent):
-        return self.obj.start()
-
-    def stop(self, parent):
-        return self.obj.stop()
-
-    def close(self, parent):
-        pass
-
-    def terminate(self, parent):
-        self.stop(parent)
-
-    def include(self, parent):
-        if super(StartStopComponent, self).include(parent):
-            parent.components.append(self)

+ 72 - 58
celery/worker/components.py

@@ -15,16 +15,55 @@ from functools import partial
 
 from billiard.exceptions import WorkerLostError
 
+from celery import bootsteps
 from celery.utils.log import worker_logger as logger
 from celery.utils.timer2 import Schedule
 
-from . import bootsteps
+from . import hub
 from .buckets import TaskBucket, FastQueue
-from .hub import Hub, BoundedSemaphore
 
 
-class Pool(bootsteps.StartStopComponent):
-    """The pool component.
+class Hub(bootsteps.StartStopStep):
+
+    def __init__(self, w, **kwargs):
+        w.hub = None
+
+    def include_if(self, w):
+        return w.use_eventloop
+
+    def create(self, w):
+        w.timer = Schedule(max_interval=10)
+        w.hub = hub.Hub(w.timer)
+        return w.hub
+
+
+class Queues(bootsteps.Step):
+    """This step initializes the internal queues
+    used by the worker."""
+    requires = (Hub, )
+
+    def create(self, w):
+        w.start_mediator = True
+        if not w.pool_cls.rlimit_safe:
+            w.disable_rate_limits = True
+        if w.disable_rate_limits:
+            w.ready_queue = FastQueue()
+            if w.use_eventloop:
+                w.start_mediator = False
+                if w.pool_putlocks and w.pool_cls.uses_semaphore:
+                    w.ready_queue.put = w.process_task_sem
+                else:
+                    w.ready_queue.put = w.process_task
+            elif not w.pool_cls.requires_mediator:
+                # just send task directly to pool, skip the mediator.
+                w.ready_queue.put = w.process_task
+                w.start_mediator = False
+        else:
+            w.ready_queue = TaskBucket(task_registry=w.app.tasks)
+
+
+class Pool(bootsteps.StartStopStep):
+    """The pool step.
 
     Describes how to initialize the worker pool, and starts and stops
     the pool during worker startup/shutdown.
@@ -37,8 +76,7 @@ class Pool(bootsteps.StartStopComponent):
         * min_concurrency
 
     """
-    name = 'worker.pool'
-    requires = ('queues', )
+    requires = (Queues, )
 
     def __init__(self, w, autoscale=None, autoreload=None,
             no_execv=False, **kwargs):
@@ -115,7 +153,7 @@ class Pool(bootsteps.StartStopComponent):
         procs = w.min_concurrency
         forking_enable = not threaded or (w.no_execv or not w.force_execv)
         if not threaded:
-            semaphore = w.semaphore = BoundedSemaphore(procs)
+            semaphore = w.semaphore = hub.BoundedSemaphore(procs)
             w._quick_acquire = w.semaphore.acquire
             w._quick_release = w.semaphore.release
             max_restarts = 100
@@ -137,14 +175,13 @@ class Pool(bootsteps.StartStopComponent):
         return pool
 
 
-class Beat(bootsteps.StartStopComponent):
-    """Component used to embed a celerybeat process.
+class Beat(bootsteps.StartStopStep):
+    """Step used to embed a celerybeat process.
 
     This will only be enabled if the ``beat``
     argument is set.
 
     """
-    name = 'worker.beat'
 
     def __init__(self, w, beat=False, **kwargs):
         self.enabled = w.beat = beat
@@ -158,51 +195,9 @@ class Beat(bootsteps.StartStopComponent):
         return b
 
 
-class Queues(bootsteps.Component):
-    """This component initializes the internal queues
-    used by the worker."""
-    name = 'worker.queues'
-    requires = ('ev', )
-
-    def create(self, w):
-        w.start_mediator = True
-        if not w.pool_cls.rlimit_safe:
-            w.disable_rate_limits = True
-        if w.disable_rate_limits:
-            w.ready_queue = FastQueue()
-            if w.use_eventloop:
-                w.start_mediator = False
-                if w.pool_putlocks and w.pool_cls.uses_semaphore:
-                    w.ready_queue.put = w.process_task_sem
-                else:
-                    w.ready_queue.put = w.process_task
-            elif not w.pool_cls.requires_mediator:
-                # just send task directly to pool, skip the mediator.
-                w.ready_queue.put = w.process_task
-                w.start_mediator = False
-        else:
-            w.ready_queue = TaskBucket(task_registry=w.app.tasks)
-
-
-class EvLoop(bootsteps.StartStopComponent):
-    name = 'worker.ev'
-
-    def __init__(self, w, **kwargs):
-        w.hub = None
-
-    def include_if(self, w):
-        return w.use_eventloop
-
-    def create(self, w):
-        w.timer = Schedule(max_interval=10)
-        hub = w.hub = Hub(w.timer)
-        return hub
-
-
-class Timers(bootsteps.Component):
-    """This component initializes the internal timers used by the worker."""
-    name = 'worker.timers'
-    requires = ('pool', )
+class Timers(bootsteps.Step):
+    """This step initializes the internal timers used by the worker."""
+    requires = (Pool, )
 
     def include_if(self, w):
         return not w.use_eventloop
@@ -224,9 +219,8 @@ class Timers(bootsteps.Component):
         logger.debug('Timer wake-up! Next eta %s secs.', delay)
 
 
-class StateDB(bootsteps.Component):
-    """This component sets up the workers state db if enabled."""
-    name = 'worker.state-db'
+class StateDB(bootsteps.Step):
+    """This step sets up the workers state db if enabled."""
 
     def __init__(self, w, **kwargs):
         self.enabled = w.state_db
@@ -235,3 +229,23 @@ class StateDB(bootsteps.Component):
     def create(self, w):
         w._persistence = w.state.Persistent(w.state_db)
         atexit.register(w._persistence.save)
+
+
+class Consumer(bootsteps.StartStopStep):
+    last = True
+
+    def create(self, w):
+        prefetch_count = w.concurrency * w.prefetch_multiplier
+        c = w.consumer = self.instantiate(w.consumer_cls,
+                w.ready_queue,
+                hostname=w.hostname,
+                send_events=w.send_events,
+                init_callback=w.ready_callback,
+                initial_prefetch_count=prefetch_count,
+                pool=w.pool,
+                timer=w.timer,
+                app=w.app,
+                controller=w,
+                hub=w.hub,
+                worker_options=w.options)
+        return c

+ 262 - 549
celery/worker/consumer.py

@@ -3,7 +3,7 @@
 celery.worker.consumer
 ~~~~~~~~~~~~~~~~~~~~~~
 
-This module contains the component responsible for consuming messages
+This module contains the components responsible for consuming messages
 from the broker, processing the messages and keeping the broker connections
 up and running.
 
@@ -12,33 +12,45 @@ from __future__ import absolute_import
 
 import logging
 import socket
-import threading
 
-from time import sleep
-from Queue import Empty
-
-from kombu.common import QoS
+from kombu.common import QoS, ignore_errors
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr
-from kombu.utils.eventio import READ, WRITE, ERR
 
+from celery import bootsteps
 from celery.app import app_or_default
-from celery.datastructures import AttributeDict
-from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.task.trace import build_tracer
-from celery.utils import text
-from celery.utils import timer2
+from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
+from celery.utils.text import truncate
 from celery.utils.timeutils import humanize_seconds, timezone
 
-from . import state
-from .bootsteps import StartStopComponent, RUN, CLOSE
-from .control import Panel
-from .heartbeat import Heart
+from . import heartbeat, loops, pidbox
+from .state import task_reserved, maybe_shutdown
+
+CLOSE = bootsteps.CLOSE
+logger = get_logger(__name__)
+debug, info, warn, error, crit = (logger.debug, logger.info, logger.warn,
+                                  logger.error, logger.critical)
+
+CONNECTION_RETRY = """\
+consumer: Connection to broker lost. \
+Trying to re-establish the connection...\
+"""
+
+CONNECTION_RETRY_STEP = """\
+Trying again {when}...\
+"""
 
-#: Heartbeat check is called every heartbeat_seconds' / rate'.
-AMQHEARTBEAT_RATE = 2.0
+CONNECTION_ERROR = """\
+consumer: Cannot connect to %s: %s.
+%s
+"""
+
+CONNECTION_FAILOVER = """\
+Will retry using next failover.\
+"""
 
 UNKNOWN_FORMAT = """\
 Received and deleted unknown message. Wrong destination?!?
@@ -53,7 +65,7 @@ The message has been ignored and discarded.
 
 Did you remember to import the module containing this task?
 Or maybe you are using relative imports?
-More: http://docs.celeryq.org/en/latest/userguide/tasks.html#names
+Please see http://bit.ly/gLye1c for more information.
 
 The full contents of the message body was:
 %s
@@ -64,8 +76,8 @@ INVALID_TASK_ERROR = """\
 Received invalid task message: %s
 The message has been ignored and discarded.
 
-Please ensure your message conforms to the task message format:
-http://docs.celeryq.org/en/latest/internals/protocol.html
+Please ensure your message conforms to the task
+message protocol as described here: http://bit.ly/hYj41y
 
 The full contents of the message body was:
 %s
@@ -76,112 +88,20 @@ body: {0} {{content_type:{1} content_encoding:{2} delivery_info:{3}}}\
 """
 
 
-RETRY_CONNECTION = """\
-consumer: Connection to broker lost. \
-Trying to re-establish the connection...\
-"""
-
-CONNECTION_ERROR = """\
-consumer: Cannot connect to %s: %s.
-%s
-"""
-
-CONNECTION_RETRY = """\
-Trying again {when}...\
-"""
-
-CONNECTION_FAILOVER = """\
-Will retry using next failover.\
-"""
-
-task_reserved = state.task_reserved
-
-logger = get_logger(__name__)
-info, warn, error, crit = (logger.info, logger.warn,
-                           logger.error, logger.critical)
-
-
-def debug(msg, *args, **kwargs):
-    logger.debug('consumer: {0}'.format(msg), *args, **kwargs)
-
-
 def dump_body(m, body):
-    return '{0} ({1}b)'.format(text.truncate(safe_repr(body), 1024),
+    return '{0} ({1}b)'.format(truncate(safe_repr(body), 1024),
                                len(m.body))
 
 
-class Component(StartStopComponent):
-    name = 'worker.consumer'
-    last = True
-
-    def Consumer(self, w):
-        return (w.consumer_cls or
-                Consumer if w.hub else BlockingConsumer)
-
-    def create(self, w):
-        prefetch_count = w.concurrency * w.prefetch_multiplier
-        c = w.consumer = self.instantiate(self.Consumer(w),
-                w.ready_queue,
-                hostname=w.hostname,
-                send_events=w.send_events,
-                init_callback=w.ready_callback,
-                initial_prefetch_count=prefetch_count,
-                pool=w.pool,
-                timer=w.timer,
-                app=w.app,
-                controller=w,
-                hub=w.hub)
-        return c
-
-
 class Consumer(object):
-    """Listen for messages received from the broker and
-    move them to the ready queue for task processing.
-
-    :param ready_queue: See :attr:`ready_queue`.
-    :param timer: See :attr:`timer`.
-
-    """
 
-    #: The queue that holds tasks ready for immediate processing.
+    #: Intra-queue for tasks ready to be handled
     ready_queue = None
 
-    #: Enable/disable events.
-    send_events = False
-
-    #: Optional callback to be called when the connection is established.
-    #: Will only be called once, even if the connection is lost and
-    #: re-established.
+    #: Optional callback called the first time the worker
+    #: is ready to receive tasks.
     init_callback = None
 
-    #: The current hostname.  Defaults to the system hostname.
-    hostname = None
-
-    #: Initial QoS prefetch count for the task channel.
-    initial_prefetch_count = 0
-
-    #: A :class:`celery.events.EventDispatcher` for sending events.
-    event_dispatcher = None
-
-    #: The thread that sends event heartbeats at regular intervals.
-    #: The heartbeats are used by monitors to detect that a worker
-    #: went offline/disappeared.
-    heart = None
-
-    #: The broker connection.
-    connection = None
-
-    #: The consumer used to consume task messages.
-    task_consumer = None
-
-    #: The consumer used to consume broadcast commands.
-    broadcast_consumer = None
-
-    #: The process mailbox (kombu pidbox node).
-    pidbox_node = None
-    _pidbox_node_shutdown = None   # used for greenlets
-    _pidbox_node_stopped = None    # used for greenlets
-
     #: The current worker pool instance.
     pool = None
 
@@ -189,187 +109,188 @@ class Consumer(object):
     #: as sending heartbeats.
     timer = None
 
-    # Consumer state, can be RUN or CLOSE.
-    _state = None
+    class Namespace(bootsteps.Namespace):
+        name = 'Consumer'
+        default_steps = [
+            'celery.worker.consumer:Connection',
+            'celery.worker.consumer:Events',
+            'celery.worker.consumer:Heart',
+            'celery.worker.consumer:Control',
+            'celery.worker.consumer:Tasks',
+            'celery.worker.consumer:Evloop',
+        ]
+
+        def shutdown(self, parent):
+            self.restart(parent, 'Shutdown', 'shutdown')
 
     def __init__(self, ready_queue,
-            init_callback=noop, send_events=False, hostname=None,
-            initial_prefetch_count=2, pool=None, app=None,
+            init_callback=noop, hostname=None,
+            pool=None, app=None,
             timer=None, controller=None, hub=None, amqheartbeat=None,
-            **kwargs):
+            worker_options=None, **kwargs):
         self.app = app_or_default(app)
-        self.connection = None
-        self.task_consumer = None
         self.controller = controller
-        self.broadcast_consumer = None
         self.ready_queue = ready_queue
-        self.send_events = send_events
         self.init_callback = init_callback
         self.hostname = hostname or socket.gethostname()
-        self.initial_prefetch_count = initial_prefetch_count
-        self.event_dispatcher = None
-        self.heart = None
         self.pool = pool
-        self.timer = timer or timer2.default_timer
-        pidbox_state = AttributeDict(app=self.app,
-                                     hostname=self.hostname,
-                                     listener=self,     # pre 2.2
-                                     consumer=self)
-        self.pidbox_node = self.app.control.mailbox.Node(self.hostname,
-                                                         state=pidbox_state,
-                                                         handlers=Panel.data)
+        self.timer = timer or default_timer
+        self.strategies = {}
         conninfo = self.app.connection()
         self.connection_errors = conninfo.connection_errors
         self.channel_errors = conninfo.channel_errors
 
         self._does_info = logger.isEnabledFor(logging.INFO)
-        self.strategies = {}
-        if hub:
-            hub.on_init.append(self.on_poll_init)
-        self.hub = hub
         self._quick_put = self.ready_queue.put
-        self.amqheartbeat = amqheartbeat
-        if self.amqheartbeat is None:
-            self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
-        if not hub:
+
+        if hub:
+            self.amqheartbeat = amqheartbeat
+            if self.amqheartbeat is None:
+                self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
+            self.hub = hub
+            self.hub.on_init.append(self.on_poll_init)
+        else:
+            self.hub = None
             self.amqheartbeat = 0
 
+        if not hasattr(self, 'loop'):
+            self.loop = loops.asynloop if hub else loops.synloop
+
         if _detect_environment() == 'gevent':
             # there's a gevent bug that causes timeouts to not be reset,
             # so if the connection timeout is exceeded once, it can NEVER
             # connect again.
             self.app.conf.BROKER_CONNECTION_TIMEOUT = None
 
-    def update_strategies(self):
-        S = self.strategies
-        app = self.app
-        loader = app.loader
-        hostname = self.hostname
-        for name, task in self.app.tasks.iteritems():
-            S[name] = task.start_strategy(app, self)
-            task.__trace__ = build_tracer(name, task, loader, hostname)
+        self.steps = []
+        self.namespace = self.Namespace(
+            app=self.app, on_start=self.on_start, on_close=self.on_close,
+        )
+        self.namespace.apply(self, **worker_options or {})
 
     def start(self):
-        """Start the consumer.
+        ns, loop = self.namespace, self.loop
+        while ns.state != CLOSE:
+            maybe_shutdown()
+            try:
+                ns.start(self)
+            except self.connection_errors + self.channel_errors:
+                maybe_shutdown()
+                if ns.state != CLOSE and self.connection:
+                    error(CONNECTION_RETRY, exc_info=True)
+                    ns.restart(self)
 
-        Automatically survives intermittent connection failure,
-        and will retry establishing the connection and restart
-        consuming messages.
+    def shutdown(self):
+        self.namespace.shutdown(self)
 
-        """
+    def stop(self):
+        self.namespace.stop(self)
 
-        self.init_callback(self)
+    def on_start(self):
+        self.update_strategies()
 
-        while self._state != CLOSE:
-            self.maybe_shutdown()
-            try:
-                self.reset_connection()
-                self.consume_messages()
-            except self.connection_errors + self.channel_errors:
-                error(RETRY_CONNECTION, exc_info=True)
+    def on_ready(self):
+        callback, self.init_callback = self.init_callback, None
+        if callback:
+            callback(self)
+
+    def loop_args(self):
+        return (self, self.connection, self.task_consumer,
+                self.strategies, self.namespace, self.hub, self.qos,
+                self.amqheartbeat, self.handle_unknown_message,
+                self.handle_unknown_task, self.handle_invalid_task)
 
     def on_poll_init(self, hub):
         hub.update_readers(self.connection.eventmap)
         self.connection.transport.on_poll_init(hub.poller)
 
-    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
-            hbrate=AMQHEARTBEAT_RATE):
-        """Consume messages forever (or until an exception is raised)."""
-
-        with self.hub as hub:
-            qos = self.qos
-            update_qos = qos.update
-            update_readers = hub.update_readers
-            readers, writers = hub.readers, hub.writers
-            poll = hub.poller.poll
-            fire_timers = hub.fire_timers
-            scheduled = hub.timer._queue
-            connection = self.connection
-            hb = self.amqheartbeat
-            hbtick = connection.heartbeat_check
-            on_poll_start = connection.transport.on_poll_start
-            on_poll_empty = connection.transport.on_poll_empty
-            strategies = self.strategies
-            drain_nowait = connection.drain_nowait
-            on_task_callbacks = hub.on_task
-            keep_draining = connection.transport.nb_keep_draining
-
-            if hb and connection.supports_heartbeats:
-                hub.timer.apply_interval(
-                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))
-
-            def on_task_received(body, message):
-                if on_task_callbacks:
-                    [callback() for callback in on_task_callbacks]
-                try:
-                    name = body['task']
-                except (KeyError, TypeError):
-                    return self.handle_unknown_message(body, message)
-                try:
-                    strategies[name](message, body, message.ack_log_error)
-                except KeyError as exc:
-                    self.handle_unknown_task(body, message, exc)
-                except InvalidTaskError as exc:
-                    self.handle_invalid_task(body, message, exc)
-                #fire_timers()
-
-            self.task_consumer.callbacks = [on_task_received]
-            self.task_consumer.consume()
-
-            debug('Ready to accept tasks!')
-
-            while self._state != CLOSE and self.connection:
-                # shutdown if signal handlers told us to.
-                if state.should_stop:
-                    raise SystemExit()
-                elif state.should_terminate:
-                    raise SystemTerminate()
-
-                # fire any ready timers, this also returns
-                # the number of seconds until we need to fire timers again.
-                poll_timeout = fire_timers() if scheduled else 1
-
-                # We only update QoS when there is no more messages to read.
-                # This groups together qos calls, and makes sure that remote
-                # control commands will be prioritized over task messages.
-                if qos.prev != qos.value:
-                    update_qos()
-
-                update_readers(on_poll_start())
-                if readers or writers:
-                    connection.more_to_read = True
-                    while connection.more_to_read:
-                        try:
-                            events = poll(poll_timeout)
-                        except ValueError:  # Issue 882
-                            return
-                        if not events:
-                            on_poll_empty()
-                        for fileno, event in events or ():
-                            try:
-                                if event & READ:
-                                    readers[fileno](fileno, event)
-                                if event & WRITE:
-                                    writers[fileno](fileno, event)
-                                if event & ERR:
-                                    for handlermap in readers, writers:
-                                        try:
-                                            handlermap[fileno](fileno, event)
-                                        except KeyError:
-                                            pass
-                            except (KeyError, Empty):
-                                continue
-                            except socket.error:
-                                if self._state != CLOSE:  # pragma: no cover
-                                    raise
-                        if keep_draining:
-                            drain_nowait()
-                            poll_timeout = 0
-                        else:
-                            connection.more_to_read = False
-                else:
-                    # no sockets yet, startup is probably not done.
-                    sleep(min(poll_timeout, 0.1))
+    def on_decode_error(self, message, exc):
+        """Callback called if an error occurs while decoding
+        a message received.
+
+        Simply logs the error and acknowledges the message so it
+        doesn't enter a loop.
+
+        :param message: The message with errors.
+        :param exc: The original exception instance.
+
+        """
+        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
+             exc, message.content_type, message.content_encoding,
+             dump_body(message, message.body))
+        message.ack()
+
+    def on_close(self):
+        # Clear internal queues to get rid of old messages.
+        # They can't be acked anyway, as a delivery tag is specific
+        # to the current channel.
+        self.ready_queue.clear()
+        self.timer.clear()
+
+    def connect(self):
+        """Establish the broker connection.
+
+        Will retry establishing the connection if the
+        :setting:`BROKER_CONNECTION_RETRY` setting is enabled
+
+        """
+        conn = self.app.connection(heartbeat=self.amqheartbeat)
+
+        # Callback called for each retry while the connection
+        # can't be established.
+        def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
+            if getattr(conn, 'alt', None) and interval == 0:
+                next_step = CONNECTION_FAILOVER
+            error(CONNECTION_ERROR, conn.as_uri(), exc,
+                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))
+
+        # remember that the connection is lazy, it won't establish
+        # until it's needed.
+        if not self.app.conf.BROKER_CONNECTION_RETRY:
+            # retry disabled, just call connect directly.
+            conn.connect()
+            return conn
+
+        return conn.ensure_connection(_error_handler,
+                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
+                    callback=maybe_shutdown)
+
+    def add_task_queue(self, queue, exchange=None, exchange_type=None,
+            routing_key=None, **options):
+        cset = self.task_consumer
+        try:
+            q = self.app.amqp.queues[queue]
+        except KeyError:
+            exchange = queue if exchange is None else exchange
+            exchange_type = 'direct' if exchange_type is None \
+                                     else exchange_type
+            q = self.app.amqp.queues.select_add(queue,
+                    exchange=exchange,
+                    exchange_type=exchange_type,
+                    routing_key=routing_key, **options)
+        if not cset.consuming_from(queue):
+            cset.add_queue(q)
+            cset.consume()
+            info('Started consuming from %r', queue)
+
+    def cancel_task_queue(self, queue):
+        self.app.amqp.queues.select_remove(queue)
+        self.task_consumer.cancel_by_queue(queue)
+
+    @property
+    def info(self):
+        """Returns information about this consumer instance
+        as a dict.
+
+        This is also the consumer related info returned by
+        ``celeryctl stats``.
+
+        """
+        conninfo = {}
+        if self.connection:
+            conninfo = self.connection.info()
+            conninfo.pop('password', None)  # don't send password.
+        return {'broker': conninfo, 'prefetch_count': self.qos.value}
 
     def on_task(self, task, task_reserved=task_reserved):
         """Handle received task.
@@ -395,7 +316,7 @@ class Consumer(object):
         if task.eta:
             eta = timezone.to_system(task.eta) if task.utc else task.eta
             try:
-                eta = timer2.to_timestamp(eta)
+                eta = to_timestamp(eta)
             except OverflowError as exc:
                 error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                       task.eta, exc, task.info(safe=True), exc_info=True)
@@ -409,16 +330,6 @@ class Consumer(object):
             task_reserved(task)
             self._quick_put(task)
 
-    def on_control(self, body, message):
-        """Process remote control command message."""
-        try:
-            self.pidbox_node.handle_message(body, message)
-        except KeyError as exc:
-            error('No such control command: %s', exc)
-        except Exception as exc:
-            error('Control command error: %r', exc, exc_info=True)
-            self.reset_pidbox_node()
-
     def apply_eta_task(self, task):
         """Method called by the timer to apply a task with an
         ETA/countdown."""
@@ -444,307 +355,109 @@ class Consumer(object):
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         message.reject_log_error(logger, self.connection_errors)
 
-    def receive_message(self, body, message):
-        """Handles incoming messages.
+    def update_strategies(self):
+        loader = self.app.loader
+        for name, task in self.app.tasks.iteritems():
+            self.strategies[name] = task.start_strategy(self.app, self)
+            task.__trace__ = build_tracer(name, task, loader, self.hostname)
 
-        :param body: The message body.
-        :param message: The kombu message object.
 
-        """
-        try:
-            name = body['task']
-        except (KeyError, TypeError):
-            return self.handle_unknown_message(body, message)
+class Connection(bootsteps.StartStopStep):
 
-        try:
-            self.strategies[name](message, body, message.ack_log_error)
-        except KeyError as exc:
-            self.handle_unknown_task(body, message, exc)
-        except InvalidTaskError as exc:
-            self.handle_invalid_task(body, message, exc)
-
-    def maybe_conn_error(self, fun):
-        """Applies function but ignores any connection or channel
-        errors raised."""
-        try:
-            fun()
-        except (AttributeError, ) + \
-                self.connection_errors + \
-                self.channel_errors:
-            pass
+    def __init__(self, c, **kwargs):
+        c.connection = None
 
-    def close_connection(self):
-        """Closes the current broker connection and all open channels."""
+    def start(self, c):
+        c.connection = c.connect()
+        info('Connected to %s', c.connection.as_uri())
 
+    def shutdown(self, c):
         # We must set self.connection to None here, so
         # that the green pidbox thread exits.
-        connection, self.connection = self.connection, None
-
-        if self.task_consumer:
-            debug('Closing consumer channel...')
-            self.task_consumer = \
-                    self.maybe_conn_error(self.task_consumer.close)
-
-        self.stop_pidbox_node()
-
+        connection, c.connection = c.connection, None
         if connection:
-            debug('Closing broker connection...')
-            self.maybe_conn_error(connection.close)
-
-    def stop_consumers(self, close_connection=True, join=True):
-        """Stop consuming tasks and broadcast commands, also stops
-        the heartbeat thread and event dispatcher.
+            ignore_errors(connection, connection.close)
 
-        :keyword close_connection: Set to False to skip closing the broker
-                                    connection.
-
-        """
-        if not self._state == RUN:
-            return
-
-        if self.heart:
-            # Stop the heartbeat thread if it's running.
-            debug('Heart: Going into cardiac arrest...')
-            self.heart = self.heart.stop()
-
-        debug('Cancelling task consumer...')
-        if join and self.task_consumer:
-            self.maybe_conn_error(self.task_consumer.cancel)
-
-        if self.event_dispatcher:
-            debug('Shutting down event dispatcher...')
-            self.event_dispatcher = \
-                    self.maybe_conn_error(self.event_dispatcher.close)
-
-        debug('Cancelling broadcast consumer...')
-        if join and self.broadcast_consumer:
-            self.maybe_conn_error(self.broadcast_consumer.cancel)
-
-        if close_connection:
-            self.close_connection()
-
-    def on_decode_error(self, message, exc):
-        """Callback called if an error occurs while decoding
-        a message received.
-
-        Simply logs the error and acknowledges the message so it
-        doesn't enter a loop.
-
-        :param message: The message with errors.
-        :param exc: The original exception instance.
-
-        """
-        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
-             exc, message.content_type, message.content_encoding,
-             dump_body(message, message.body))
-        message.ack()
-
-    def reset_pidbox_node(self):
-        """Sets up the process mailbox."""
-        self.stop_pidbox_node()
-        # close previously opened channel if any.
-        if self.pidbox_node.channel:
-            try:
-                self.pidbox_node.channel.close()
-            except self.connection_errors + self.channel_errors:
-                pass
-
-        if self.pool is not None and self.pool.is_green:
-            return self.pool.spawn_n(self._green_pidbox_node)
-        self.pidbox_node.channel = self.connection.channel()
-        self.broadcast_consumer = self.pidbox_node.listen(
-                                        callback=self.on_control)
-
-    def stop_pidbox_node(self):
-        if self._pidbox_node_stopped:
-            self._pidbox_node_shutdown.set()
-            debug('Waiting for broadcast thread to shutdown...')
-            self._pidbox_node_stopped.wait()
-            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
-        elif self.broadcast_consumer:
-            debug('Closing broadcast channel...')
-            self.broadcast_consumer = \
-                self.maybe_conn_error(self.broadcast_consumer.channel.close)
-
-    def _green_pidbox_node(self):
-        """Sets up the process mailbox when running in a greenlet
-        environment."""
-        # THIS CODE IS TERRIBLE
-        # Luckily work has already started rewriting the Consumer for 4.0.
-        self._pidbox_node_shutdown = threading.Event()
-        self._pidbox_node_stopped = threading.Event()
-        try:
-            with self._open_connection() as conn:
-                info('pidbox: Connected to %s.', conn.as_uri())
-                self.pidbox_node.channel = conn.default_channel
-                self.broadcast_consumer = self.pidbox_node.listen(
-                                            callback=self.on_control)
-                with self.broadcast_consumer:
-                    while not self._pidbox_node_shutdown.isSet():
-                        try:
-                            conn.drain_events(timeout=1.0)
-                        except socket.timeout:
-                            pass
-        finally:
-            self._pidbox_node_stopped.set()
-
-    def reset_connection(self):
-        """Re-establish the broker connection and set up consumers,
-        heartbeat and the event dispatcher."""
-        debug('Re-establishing connection to the broker...')
-        self.stop_consumers(join=False)
-
-        # Clear internal queues to get rid of old messages.
-        # They can't be acked anyway, as a delivery tag is specific
-        # to the current channel.
-        self.ready_queue.clear()
-        self.timer.clear()
 
-        # Re-establish the broker connection and setup the task consumer.
-        self.connection = self._open_connection()
-        info('consumer: Connected to %s.', self.connection.as_uri())
-        self.task_consumer = self.app.amqp.TaskConsumer(self.connection,
-                                    on_decode_error=self.on_decode_error)
-        # QoS: Reset prefetch window.
-        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
-        self.qos.update()
+class Events(bootsteps.StartStopStep):
+    requires = (Connection, )
 
-        # Setup the process mailbox.
-        self.reset_pidbox_node()
+    def __init__(self, c, send_events=None, **kwargs):
+        self.send_events = send_events
+        c.event_dispatcher = None
 
+    def start(self, c):
         # Flush events sent while connection was down.
-        prev_event_dispatcher = self.event_dispatcher
-        self.event_dispatcher = self.app.events.Dispatcher(self.connection,
-                                                hostname=self.hostname,
-                                                enabled=self.send_events)
-        if prev_event_dispatcher:
-            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
-            self.event_dispatcher.flush()
-
-        # Restart heartbeat thread.
-        self.restart_heartbeat()
-
-        # reload all task's execution strategies.
-        self.update_strategies()
-
-        # We're back!
-        self._state = RUN
-
-    def restart_heartbeat(self):
-        """Restart the heartbeat thread.
-
-        This thread sends heartbeat events at intervals so monitors
-        can tell if the worker is off-line/missing.
-
-        """
-        self.heart = Heart(self.timer, self.event_dispatcher)
-        self.heart.start()
-
-    def _open_connection(self):
-        """Establish the broker connection.
-
-        Will retry establishing the connection if the
-        :setting:`BROKER_CONNECTION_RETRY` setting is enabled
-
-        """
-        conn = self.app.connection(heartbeat=self.amqheartbeat)
-
-        # Callback called for each retry while the connection
-        # can't be established.
-        def _error_handler(exc, interval, next_step=CONNECTION_RETRY):
-            if getattr(conn, 'alt', None) and interval == 0:
-                next_step = CONNECTION_FAILOVER
-            error(CONNECTION_ERROR, conn.as_uri(), exc,
-                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))
+        prev = c.event_dispatcher
+        dis = c.event_dispatcher = c.app.events.Dispatcher(
+            c.connection, hostname=c.hostname, enabled=self.send_events,
+        )
+        if prev:
+            dis.copy_buffer(prev)
+            dis.flush()
 
-        # remember that the connection is lazy, it won't establish
-        # until it's needed.
-        if not self.app.conf.BROKER_CONNECTION_RETRY:
-            # retry disabled, just call connect directly.
-            conn.connect()
-            return conn
-
-        return conn.ensure_connection(_error_handler,
-                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
-                    callback=self.maybe_shutdown)
+    def stop(self, c):
+        if c.event_dispatcher:
+            ignore_errors(c, c.event_dispatcher.close)
+            c.event_dispatcher = None
+    shutdown = stop
 
-    def stop(self):
-        """Stop consuming.
 
-        Does not close the broker connection, so be sure to call
-        :meth:`close_connection` when you are finished with it.
+class Heart(bootsteps.StartStopStep):
+    requires = (Events, )
 
-        """
-        # Notifies other threads that this instance can't be used
-        # anymore.
-        self.close()
-        debug('Stopping consumers...')
-        self.stop_consumers(close_connection=False, join=True)
+    def __init__(self, c, **kwargs):
+        c.heart = None
 
-    def close(self):
-        self._state = CLOSE
+    def start(self, c):
+        c.heart = heartbeat.Heart(c.timer, c.event_dispatcher)
+        c.heart.start()
 
-    def maybe_shutdown(self):
-        if state.should_stop:
-            raise SystemExit()
-        elif state.should_terminate:
-            raise SystemTerminate()
+    def stop(self, c):
+        c.heart = c.heart and c.heart.stop()
+    shutdown = stop
 
-    def add_task_queue(self, queue, exchange=None, exchange_type=None,
-            routing_key=None, **options):
-        cset = self.task_consumer
-        try:
-            q = self.app.amqp.queues[queue]
-        except KeyError:
-            exchange = queue if exchange is None else exchange
-            exchange_type = 'direct' if exchange_type is None \
-                                     else exchange_type
-            q = self.app.amqp.queues.select_add(queue,
-                    exchange=exchange,
-                    exchange_type=exchange_type,
-                    routing_key=routing_key, **options)
-        if not cset.consuming_from(queue):
-            cset.add_queue(q)
-            cset.consume()
-            logger.info('Started consuming from %r', queue)
 
-    def cancel_task_queue(self, queue):
-        self.app.amqp.queues.select_remove(queue)
-        self.task_consumer.cancel_by_queue(queue)
+class Control(bootsteps.StartStopStep):
+    requires = (Events, )
 
-    @property
-    def info(self):
-        """Returns information about this consumer instance
-        as a dict.
+    def __init__(self, c, **kwargs):
+        self.is_green = c.pool is not None and c.pool.is_green
+        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
+        self.start = self.box.start
+        self.stop = self.box.stop
+        self.shutdown = self.box.shutdown
 
-        This is also the consumer related info returned by
-        ``celeryctl stats``.
 
-        """
-        conninfo = {}
-        if self.connection:
-            conninfo = self.connection.info()
-            conninfo.pop('password', None)  # don't send password.
-        return {'broker': conninfo, 'prefetch_count': self.qos.value}
+class Tasks(bootsteps.StartStopStep):
+    requires = (Control, )
 
+    def __init__(self, c, initial_prefetch_count=2, **kwargs):
+        c.task_consumer = c.qos = None
+        self.initial_prefetch_count = initial_prefetch_count
 
-class BlockingConsumer(Consumer):
+    def start(self, c):
+        c.task_consumer = c.app.amqp.TaskConsumer(
+            c.connection, on_decode_error=c.on_decode_error,
+        )
+        c.qos = QoS(c.task_consumer, self.initial_prefetch_count)
+        c.qos.update()  # set initial prefetch count
+
+    def stop(self, c):
+        if c.task_consumer:
+            debug('Cancelling task consumer...')
+            ignore_errors(c, c.task_consumer.cancel)
+
+    def shutdown(self, c):
+        if c.task_consumer:
+            self.stop(c)
+            debug('Closing consumer channel...')
+            ignore_errors(c, c.task_consumer.close)
+            c.task_consumer = None
 
-    def consume_messages(self):
-        # receive_message handles incoming messages.
-        self.task_consumer.register_callback(self.receive_message)
-        self.task_consumer.consume()
 
-        debug('Ready to accept tasks!')
+class Evloop(bootsteps.StartStopStep):
+    last = True
 
-        while self._state != CLOSE and self.connection:
-            self.maybe_shutdown()
-            if self.qos.prev != self.qos.value:     # pragma: no cover
-                self.qos.update()
-            try:
-                self.connection.drain_events(timeout=10.0)
-            except socket.timeout:
-                pass
-            except socket.error:
-                if self._state != CLOSE:            # pragma: no cover
-                    raise
+    def start(self, c):
+        c.loop(*c.loop_args())

+ 2 - 2
celery/worker/hub.py

@@ -132,11 +132,11 @@ class Hub(object):
         self.on_task = []
 
     def start(self):
-        """Called by StartStopComponent at worker startup."""
+        """Called by Hub bootstep at worker startup."""
         self.poller = eventio.poll()
 
     def stop(self):
-        """Called by StartStopComponent at worker shutdown."""
+        """Called by Hub bootstep at worker shutdown."""
         self.poller.close()
 
     def init(self):

+ 156 - 0
celery/worker/loops.py

@@ -0,0 +1,156 @@
+"""
+celery.worker.loop
+~~~~~~~~~~~~~~~~~~
+
+The consumers highly-optimized inner loop.
+
+"""
+from __future__ import absolute_import
+
+import socket
+
+from time import sleep
+from Queue import Empty
+
+from kombu.utils.eventio import READ, WRITE, ERR
+
+from celery.bootsteps import CLOSE
+from celery.exceptions import InvalidTaskError, SystemTerminate
+
+from . import state
+
+#: Heartbeat check is called every heartbeat_seconds' / rate'.
+AMQHEARTBEAT_RATE = 2.0
+
+
+def asynloop(obj, connection, consumer, strategies, ns, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, sleep=sleep, min=min, Empty=Empty,
+        hbrate=AMQHEARTBEAT_RATE):
+    """Non-blocking eventloop consuming messages until connection is lost,
+    or shutdown is requested."""
+
+    with hub as hub:
+        update_qos = qos.update
+        update_readers = hub.update_readers
+        readers, writers = hub.readers, hub.writers
+        poll = hub.poller.poll
+        fire_timers = hub.fire_timers
+        scheduled = hub.timer._queue
+        hbtick = connection.heartbeat_check
+        on_poll_start = connection.transport.on_poll_start
+        on_poll_empty = connection.transport.on_poll_empty
+        drain_nowait = connection.drain_nowait
+        on_task_callbacks = hub.on_task
+        keep_draining = connection.transport.nb_keep_draining
+
+        if heartbeat and connection.supports_heartbeats:
+            hub.timer.apply_interval(
+                heartbeat * 1000.0 / hbrate, hbtick, (hbrate, ))
+
+        def on_task_received(body, message):
+            if on_task_callbacks:
+                [callback() for callback in on_task_callbacks]
+            try:
+                name = body['task']
+            except (KeyError, TypeError):
+                return handle_unknown_message(body, message)
+            try:
+                strategies[name](message, body, message.ack_log_error)
+            except KeyError as exc:
+                handle_unknown_task(body, message, exc)
+            except InvalidTaskError as exc:
+                handle_invalid_task(body, message, exc)
+
+        consumer.callbacks = [on_task_received]
+        consumer.consume()
+        obj.on_ready()
+
+        while ns.state != CLOSE and obj.connection:
+            # shutdown if signal handlers told us to.
+            if state.should_stop:
+                raise SystemExit()
+            elif state.should_terminate:
+                raise SystemTerminate()
+
+            # fire any ready timers, this also returns
+            # the number of seconds until we need to fire timers again.
+            poll_timeout = fire_timers() if scheduled else 1
+
+            # We only update QoS when there is no more messages to read.
+            # This groups together qos calls, and makes sure that remote
+            # control commands will be prioritized over task messages.
+            if qos.prev != qos.value:
+                update_qos()
+
+            update_readers(on_poll_start())
+            if readers or writers:
+                connection.more_to_read = True
+                while connection.more_to_read:
+                    try:
+                        events = poll(poll_timeout)
+                    except ValueError:  # Issue 882
+                        return
+                    if not events:
+                        on_poll_empty()
+                    for fileno, event in events or ():
+                        try:
+                            if event & READ:
+                                readers[fileno](fileno, event)
+                            if event & WRITE:
+                                writers[fileno](fileno, event)
+                            if event & ERR:
+                                for handlermap in readers, writers:
+                                    try:
+                                        handlermap[fileno](fileno, event)
+                                    except KeyError:
+                                        pass
+                        except (KeyError, Empty):
+                            continue
+                        except socket.error:
+                            if ns.state != CLOSE:  # pragma: no cover
+                                raise
+                    if keep_draining:
+                        drain_nowait()
+                        poll_timeout = 0
+                    else:
+                        connection.more_to_read = False
+            else:
+                # no sockets yet, startup is probably not done.
+                sleep(min(poll_timeout, 0.1))
+
+
+def synloop(obj, connection, consumer, strategies, ns, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, **kwargs):
+    """Fallback blocking eventloop for transports that doesn't support AIO."""
+
+    def on_task_received(body, message):
+        try:
+            name = body['task']
+        except (KeyError, TypeError):
+            return handle_unknown_message(body, message)
+
+        try:
+            strategies[name](message, body, message.ack_log_error)
+        except KeyError as exc:
+            handle_unknown_task(body, message, exc)
+        except InvalidTaskError as exc:
+            handle_invalid_task(body, message, exc)
+
+    consumer.register_callback(on_task_received)
+    consumer.consume()
+
+    obj.on_ready()
+
+    while ns.state != CLOSE and obj.connection:
+        state.maybe_shutdown()
+        if qos.prev != qos.value:         # pragma: no cover
+            qos.update()
+        try:
+            connection.drain_events(timeout=2.0)
+        except socket.timeout:
+            pass
+        except socket.error:
+            if ns.state != CLOSE:  # pragma: no cover
+                raise

+ 4 - 4
celery/worker/mediator.py

@@ -20,17 +20,17 @@ import logging
 from Queue import Empty
 
 from celery.app import app_or_default
+from celery.bootsteps import StartStopStep
 from celery.utils.threads import bgThread
 from celery.utils.log import get_logger
 
-from .bootsteps import StartStopComponent
+from . import components
 
 logger = get_logger(__name__)
 
 
-class WorkerComponent(StartStopComponent):
-    name = 'worker.mediator'
-    requires = ('pool', 'queues', )
+class WorkerComponent(StartStopStep):
+    requires = (components.Pool, components.Queues, )
 
     def __init__(self, w, **kwargs):
         w.mediator = None

+ 103 - 0
celery/worker/pidbox.py

@@ -0,0 +1,103 @@
+from __future__ import absolute_import
+
+import socket
+import threading
+
+from kombu.common import ignore_errors
+
+from celery.datastructures import AttributeDict
+from celery.utils.log import get_logger
+
+from . import control
+
+logger = get_logger(__name__)
+debug, error, info = logger.debug, logger.error, logger.info
+
+
+class Pidbox(object):
+    consumer = None
+
+    def __init__(self, c):
+        self.c = c
+        self.hostname = c.hostname
+        self.node = c.app.control.mailbox.Node(c.hostname,
+            handlers=control.Panel.data,
+            state=AttributeDict(app=c.app, hostname=c.hostname, consumer=c),
+        )
+
+    def on_message(self, body, message):
+        try:
+            self.node.handle_message(body, message)
+        except KeyError as exc:
+            error('No such control command: %s', exc)
+        except Exception as exc:
+            error('Control command error: %r', exc, exc_info=True)
+            self.reset()
+
+    def start(self, c):
+        self.node.channel = c.connection.channel()
+        self.consumer = self.node.listen(callback=self.on_message)
+
+    def stop(self, c):
+        self.consumer = self._close_channel(c)
+
+    def reset(self):
+        """Sets up the process mailbox."""
+        self.stop(self.c)
+        self.start(self.c)
+
+    def _close_channel(self, c):
+        if self.node and self.node.channel:
+            ignore_errors(c, self.node.channel.close)
+
+    def shutdown(self, c):
+        if self.consumer:
+            debug('Cancelling broadcast consumer...')
+            ignore_errors(c, self.consumer.cancel)
+        self.stop(self.c)
+
+
+class gPidbox(Pidbox):
+    _node_shutdown = None
+    _node_stopped = None
+    _resets = 0
+
+    def start(self, c):
+        c.pool.spawn_n(self.loop, c)
+
+    def stop(self, c):
+        if self._node_stopped:
+            self._node_shutdown.set()
+            debug('Waiting for broadcast thread to shutdown...')
+            self._node_stopped.wait()
+            self._node_stopped = self._node_shutdown = None
+        super(gPidbox, self).stop(c)
+
+    def reset(self):
+        self._resets += 1
+
+    def _do_reset(self, c, connection):
+        self._close_channel(c)
+        self.node.channel = connection.channel()
+        self.consumer = self.node.listen(callback=self.on_message)
+        self.consumer.consume()
+
+    def loop(self, c):
+        resets = [self._resets]
+        shutdown = self._node_shutdown = threading.Event()
+        stopped = self._node_stopped = threading.Event()
+        try:
+            with c.connect() as connection:
+
+                info('pidbox: Connected to %s.', connection.as_uri())
+                self._do_reset(c, connection)
+                while not shutdown.is_set() and c.connection:
+                    if resets[0] < self._resets:
+                        resets[0] += 1
+                        self._do_reset(c, connection)
+                    try:
+                        connection.drain_events(timeout=1.0)
+                    except socket.timeout:
+                        pass
+        finally:
+            stopped.set()

+ 8 - 0
celery/worker/state.py

@@ -20,6 +20,7 @@ from collections import defaultdict
 from kombu.utils import cached_property
 
 from celery import __version__
+from celery.exceptions import SystemTerminate
 from celery.datastructures import LimitedSet
 
 #: Worker software/platform information.
@@ -53,6 +54,13 @@ should_stop = False
 should_terminate = False
 
 
+def maybe_shutdown():
+    if should_stop:
+        raise SystemExit()
+    elif should_terminate:
+        raise SystemTerminate()
+
+
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     active_requests.add(request)

+ 12 - 1
docs/configuration.rst

@@ -1408,9 +1408,20 @@ CELERYD_BOOT_STEPS
 ~~~~~~~~~~~~~~~~~~
 
 This setting enables you to add additional components to the worker process.
-It should be a list of module names with :class:`celery.abstract.Component`
+It should be a list of module names with
+:class:`celery.bootsteps.Step`
 classes, that augments functionality in the worker.
 
+.. setting:: CELERYD_CONSUMER_BOOT_STEPS
+
+CELERYD_CONSUMER_BOOT_STEPS
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This setting enables you to add additional components to the workers consumer.
+It should be a list of module names with
+:class:`celery.bootsteps.Step`` classes, that augments
+functionality in the consumer.
+
 .. setting:: CELERYD_POOL
 
 CELERYD_POOL

+ 0 - 1
docs/internals/reference/index.rst

@@ -21,7 +21,6 @@
     celery.worker.strategy
     celery.worker.autoreload
     celery.worker.autoscale
-    celery.worker.bootsteps
     celery.concurrency
     celery.concurrency.solo
     celery.concurrency.processes

+ 3 - 3
docs/internals/reference/celery.worker.bootsteps.rst → docs/reference/celery.bootsteps.rst

@@ -1,11 +1,11 @@
 ==========================================
- celery.worker.bootsteps
+ celery.bootsteps
 ==========================================
 
 .. contents::
     :local:
-.. currentmodule:: celery.worker.bootsteps
+.. currentmodule:: celery.bootsteps
 
-.. automodule:: celery.worker.bootsteps
+.. automodule:: celery.bootsteps
     :members:
     :undoc-members: