Ver Fonte

Merge branch 'bootsteps_refactor'

Ask Solem há 12 anos atrás
pai
commit
f70c4aacc7

+ 3 - 1
celery/app/base.py

@@ -11,7 +11,7 @@ from __future__ import absolute_import
 import threading
 import threading
 import warnings
 import warnings
 
 
-from collections import deque
+from collections import defaultdict, deque
 from contextlib import contextmanager
 from contextlib import contextmanager
 from copy import deepcopy
 from copy import deepcopy
 from functools import wraps
 from functools import wraps
@@ -72,6 +72,8 @@ class Celery(object):
         self.set_as_current = set_as_current
         self.set_as_current = set_as_current
         self.registry_cls = symbol_by_name(self.registry_cls)
         self.registry_cls = symbol_by_name(self.registry_cls)
         self.accept_magic_kwargs = accept_magic_kwargs
         self.accept_magic_kwargs = accept_magic_kwargs
+        self.user_options = defaultdict(set)
+        self.steps = defaultdict(set)
 
 
         self.configured = False
         self.configured = False
         self._pending_defaults = deque()
         self._pending_defaults = deque()

+ 6 - 5
celery/app/defaults.py

@@ -150,22 +150,23 @@ NAMESPACES = {
         'WORKER_DIRECT': Option(False, type='bool'),
         'WORKER_DIRECT': Option(False, type='bool'),
     },
     },
     'CELERYD': {
     'CELERYD': {
-        'AUTOSCALER': Option('celery.worker.autoscale.Autoscaler'),
-        'AUTORELOADER': Option('celery.worker.autoreload.Autoreloader'),
+        'AUTOSCALER': Option('celery.worker.autoscale:Autoscaler'),
+        'AUTORELOADER': Option('celery.worker.autoreload:Autoreloader'),
         'BOOT_STEPS': Option((), type='tuple'),
         'BOOT_STEPS': Option((), type='tuple'),
+        'CONSUMER_BOOT_STEPS': Option((), type='tuple'),
         'CONCURRENCY': Option(0, type='int'),
         'CONCURRENCY': Option(0, type='int'),
         'TIMER': Option(type='string'),
         'TIMER': Option(type='string'),
         'TIMER_PRECISION': Option(1.0, type='float'),
         'TIMER_PRECISION': Option(1.0, type='float'),
         'FORCE_EXECV': Option(True, type='bool'),
         'FORCE_EXECV': Option(True, type='bool'),
         'HIJACK_ROOT_LOGGER': Option(True, type='bool'),
         'HIJACK_ROOT_LOGGER': Option(True, type='bool'),
-        'CONSUMER': Option(type='string'),
+        'CONSUMER': Option('celery.worker.consumer:Consumer', type='string'),
         'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT),
         'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT),
         'LOG_COLOR': Option(type='bool'),
         'LOG_COLOR': Option(type='bool'),
         'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0',
         'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0',
                             alt='--loglevel argument'),
                             alt='--loglevel argument'),
         'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
         'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
                             alt='--logfile argument'),
                             alt='--logfile argument'),
-        'MEDIATOR': Option('celery.worker.mediator.Mediator'),
+        'MEDIATOR': Option('celery.worker.mediator:Mediator'),
         'MAX_TASKS_PER_CHILD': Option(type='int'),
         'MAX_TASKS_PER_CHILD': Option(type='int'),
         'POOL': Option(DEFAULT_POOL),
         'POOL': Option(DEFAULT_POOL),
         'POOL_PUTLOCKS': Option(True, type='bool'),
         'POOL_PUTLOCKS': Option(True, type='bool'),
@@ -179,7 +180,7 @@ NAMESPACES = {
     },
     },
     'CELERYBEAT': {
     'CELERYBEAT': {
         'SCHEDULE': Option({}, type='dict'),
         'SCHEDULE': Option({}, type='dict'),
-        'SCHEDULER': Option('celery.beat.PersistentScheduler'),
+        'SCHEDULER': Option('celery.beat:PersistentScheduler'),
         'SCHEDULE_FILENAME': Option('celerybeat-schedule'),
         'SCHEDULE_FILENAME': Option('celerybeat-schedule'),
         'MAX_LOOP_INTERVAL': Option(0, type='float'),
         'MAX_LOOP_INTERVAL': Option(0, type='float'),
         'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0',
         'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0',

+ 8 - 3
celery/apps/worker.py

@@ -101,6 +101,10 @@ class Worker(WorkController):
             enabled=not no_color if no_color is not None else no_color
             enabled=not no_color if no_color is not None else no_color
         )
         )
 
 
+    def on_init_namespace(self):
+        print('SETUP LOGGING: %r' % (self.redirect_stdouts, ))
+        self.setup_logging()
+
     def on_start(self):
     def on_start(self):
         WorkController.on_start(self)
         WorkController.on_start(self)
 
 
@@ -122,10 +126,11 @@ class Worker(WorkController):
 
 
         # Dump configuration to screen so we have some basic information
         # Dump configuration to screen so we have some basic information
         # for when users sends bug reports.
         # for when users sends bug reports.
-        print(str(self.colored.cyan(' \n', self.startup_info())) +
-              str(self.colored.reset(self.extra_info() or '')))
+        sys.__stdout__.write(
+            str(self.colored.cyan(' \n', self.startup_info())) +
+            str(self.colored.reset(self.extra_info() or '')) + '\n'
+        )
         self.set_process_status('-active-')
         self.set_process_status('-active-')
-        self.setup_logging()
         self.install_platform_tweaks(self)
         self.install_platform_tweaks(self)
 
 
     def on_consumer_ready(self, consumer):
     def on_consumer_ready(self, consumer):

+ 5 - 0
celery/bin/__init__.py

@@ -0,0 +1,5 @@
+from __future__ import absolute_import
+
+from collections import defaultdict
+
+from .base import Option  # noqa

+ 33 - 1
celery/bin/celery.py

@@ -10,6 +10,7 @@ from __future__ import absolute_import, print_function
 
 
 import anyjson
 import anyjson
 import heapq
 import heapq
+import os
 import sys
 import sys
 import warnings
 import warnings
 
 
@@ -26,6 +27,12 @@ from celery.utils.timeutils import maybe_iso8601
 
 
 from celery.bin.base import Command as BaseCommand, Option
 from celery.bin.base import Command as BaseCommand, Option
 
 
+try:
+    # print_statement does not work with io.StringIO
+    from io import BytesIO as PrintIO
+except ImportError:
+    from StringIO import StringIO as PrintIO  # noqa
+
 HELP = """
 HELP = """
 ---- -- - - ---- Commands- -------------- --- ------------
 ---- -- - - ---- Commands- -------------- --- ------------
 
 
@@ -40,13 +47,18 @@ Migrating task {state.count}/{state.strtotal}: \
 {body[task]}[{body[id]}]\
 {body[task]}[{body[id]}]\
 """
 """
 
 
-commands = {}
+DEBUG = os.environ.get('C_DEBUG', False)
 
 
+commands = {}
 command_classes = [
 command_classes = [
     ('Main', ['worker', 'events', 'beat', 'shell', 'multi', 'amqp'], 'green'),
     ('Main', ['worker', 'events', 'beat', 'shell', 'multi', 'amqp'], 'green'),
     ('Remote Control', ['status', 'inspect', 'control'], 'blue'),
     ('Remote Control', ['status', 'inspect', 'control'], 'blue'),
     ('Utils', ['purge', 'list', 'migrate', 'call', 'result', 'report'], None),
     ('Utils', ['purge', 'list', 'migrate', 'call', 'result', 'report'], None),
 ]
 ]
+if DEBUG:
+    command_classes.append(
+        ('Debug', ['worker_graph', 'consumer_graph'], 'red'),
+    )
 
 
 
 
 @memoize()
 @memoize()
@@ -458,6 +470,26 @@ class result(Command):
         self.out(self.prettify(value)[1])
         self.out(self.prettify(value)[1])
 
 
 
 
+@command
+class worker_graph(Command):
+
+    def run(self, **kwargs):
+        worker = self.app.WorkController()
+        out = PrintIO()
+        worker.namespace.graph.to_dot(out)
+        self.out(out.getvalue())
+
+
+@command
+class consumer_graph(Command):
+
+    def run(self, **kwargs):
+        worker = self.app.WorkController()
+        out = PrintIO()
+        worker.consumer.namespace.graph.to_dot(out)
+        self.out(out.getvalue())
+
+
 class _RemoteControl(Command):
 class _RemoteControl(Command):
     name = None
     name = None
     choices = None
     choices = None

+ 1 - 1
celery/bin/celeryd.py

@@ -197,7 +197,7 @@ class WorkerCommand(Command):
             Option('--autoreload', action='store_true'),
             Option('--autoreload', action='store_true'),
             Option('--no-execv', action='store_true', default=False),
             Option('--no-execv', action='store_true', default=False),
             Option('-D', '--detach', action='store_true'),
             Option('-D', '--detach', action='store_true'),
-        ) + daemon_options()
+        ) + daemon_options() + tuple(self.app.user_options['worker'])
 
 
 
 
 def main():
 def main():

+ 330 - 0
celery/bootsteps.py

@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+"""
+    celery.bootsteps
+    ~~~~~~~~~~~~~~~~
+
+    The boot-steps!
+
+"""
+from __future__ import absolute_import
+
+from collections import deque
+from importlib import import_module
+from threading import Event
+
+from kombu.common import ignore_errors
+from kombu.utils import symbol_by_name
+
+from .datastructures import DependencyGraph
+from .utils.imports import instantiate, qualname, symbol_by_name
+from .utils.log import get_logger
+from .utils.threads import default_socket_timeout
+
+try:
+    from greenlet import GreenletExit
+    IGNORE_ERRORS = (GreenletExit, )
+except ImportError:  # pragma: no cover
+    IGNORE_ERRORS = ()
+
+#: Default socket timeout at shutdown.
+SHUTDOWN_SOCKET_TIMEOUT = 5.0
+
+#: States
+RUN = 0x1
+CLOSE = 0x2
+TERMINATE = 0x3
+
+logger = get_logger(__name__)
+debug = logger.debug
+
+
+def _pre(ns, fmt):
+    return '| {0}: {1}'.format(ns.alias, fmt)
+
+
+def _maybe_name(s):
+    if not isinstance(s, basestring):
+        return s.name
+    return s
+
+
+class Namespace(object):
+    """A namespace containing bootsteps.
+
+    :keyword steps: List of steps.
+    :keyword name: Set explicit name for this namespace.
+    :keyword app: Set the Celery app for this namespace.
+    :keyword on_start: Optional callback applied after namespace start.
+    :keyword on_close: Optional callback applied before namespace close.
+    :keyword on_stopped: Optional callback applied after namespace stopped.
+
+    """
+    name = None
+    state = None
+    started = 0
+    default_steps = set()
+
+    def __init__(self, steps=None, name=None, app=None, on_start=None,
+            on_close=None, on_stopped=None):
+        self.app = app
+        self.name = name or self.name or qualname(type(self))
+        self.types = set(steps or []) | set(self.default_steps)
+        self.on_start = on_start
+        self.on_close = on_close
+        self.on_stopped = on_stopped
+        self.shutdown_complete = Event()
+        self.steps = {}
+
+    def start(self, parent):
+        self.state = RUN
+        if self.on_start:
+            self.on_start()
+        for i, step in enumerate(filter(None, parent.steps)):
+            self._debug('Starting %s', step.alias)
+            self.started = i + 1
+            step.start(parent)
+            debug('^-- substep ok')
+
+    def close(self, parent):
+        if self.on_close:
+            self.on_close()
+        for step in parent.steps:
+            close = getattr(step, 'close', None)
+            if close:
+                close(parent)
+
+    def restart(self, parent, description='Restarting', attr='stop'):
+        with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
+            for step in reversed(parent.steps):
+                if step:
+                    self._debug('%s %s...', description, step.alias)
+                    fun = getattr(step, attr, None)
+                    if fun:
+                        fun(parent)
+
+    def stop(self, parent, close=True, terminate=False):
+        what = 'Terminating' if terminate else 'Stopping'
+        if self.state in (CLOSE, TERMINATE):
+            return
+
+        self.close(parent)
+
+        if self.state != RUN or self.started != len(parent.steps):
+            # Not fully started, can safely exit.
+            self.state = TERMINATE
+            self.shutdown_complete.set()
+            return
+        self.state = CLOSE
+        self.restart(parent, what, 'terminate' if terminate else 'stop')
+
+        if self.on_stopped:
+            self.on_stopped()
+        self.state = TERMINATE
+        self.shutdown_complete.set()
+
+    def join(self, timeout=None):
+        try:
+            # Will only get here if running green,
+            # makes sure all greenthreads have exited.
+            self.shutdown_complete.wait(timeout=timeout)
+        except IGNORE_ERRORS:
+            pass
+
+    def apply(self, parent, **kwargs):
+        """Apply the steps in this namespace to an object.
+
+        This will apply the ``__init__`` and ``include`` methods
+        of each steps with the object as argument.
+
+        For :class:`StartStopStep` the services created
+        will also be added to the object's ``steps`` attribute.
+
+        """
+        self._debug('Loading boot-steps.')
+        order = self.order = []
+        steps = self.steps = self.claim_steps()
+
+        self._debug('Building graph...')
+        for name in self._finalize_boot_steps(steps):
+            step = steps[name] = steps[name](parent, **kwargs)
+            order.append(step)
+            step.include(parent)
+        self._debug('New boot order: {%s}',
+                    ', '.join(s.alias for s in self.order))
+        return self
+
+    def import_module(self, module):
+        return import_module(module)
+
+    def __getitem__(self, name):
+        return self.steps[name]
+
+    def _find_last(self):
+        for C in self.steps.itervalues():
+            if C.last:
+                return C
+
+    def _firstpass(self, steps):
+        stream = deque(step.requires for step in steps.itervalues())
+        while stream:
+            for node in stream.popleft():
+                node = symbol_by_name(node)
+                if node.name not in self.steps:
+                    steps[node.name] = node
+                stream.append(node.requires)
+        for node in steps.itervalues():
+            node.requires = [_maybe_name(n) for n in node.requires]
+        for step in steps.values():
+            [steps[n] for n in step.requires]
+
+    def _finalize_boot_steps(self, steps):
+        self._firstpass(steps)
+        G = self.graph = DependencyGraph((C.name, C.requires)
+                            for C in steps.itervalues())
+        last = self._find_last()
+        if last:
+            for obj in G:
+                if obj != last.name:
+                    G.add_edge(last.name, obj)
+        try:
+            return G.topsort()
+        except KeyError as exc:
+            raise KeyError('unknown boot-step: %s' % exc)
+
+    def claim_steps(self):
+        return dict(self.load_step(step) for step in self._all_steps())
+
+    def _all_steps(self):
+        return self.types | self.app.steps[self.name]
+
+    def load_step(self, step):
+        step = symbol_by_name(step)
+        return step.name, step
+
+    def _debug(self, msg, *args):
+        return debug(_pre(self, msg), *args)
+
+    @property
+    def alias(self):
+        return self.name.rsplit('.', 1)[-1]
+
+
+class StepType(type):
+    """Metaclass for steps."""
+
+    def __new__(cls, name, bases, attrs):
+        module = attrs.get('__module__')
+        qname = '{0}.{1}'.format(module, name) if module else name
+        attrs.update(
+            __qualname__=qname,
+            name=attrs.get('name') or qname,
+            requires=attrs.get('requires', ()),
+        )
+        return super(StepType, cls).__new__(cls, name, bases, attrs)
+
+    def __repr__(self):
+        return 'step:{0.name}{{{0.requires!r}}}'.format(self)
+
+
+class Step(object):
+    """A Bootstep.
+
+    The :meth:`__init__` method is called when the step
+    is bound to a parent object, and can as such be used
+    to initialize attributes in the parent object at
+    parent instantiation-time.
+
+    """
+    __metaclass__ = StepType
+
+    #: Optional step name, will use qualname if not specified.
+    name = None
+
+    #: List of other steps that must be started before this step.
+    #: Note that all dependencies must be in the same namespace.
+    requires = ()
+
+    #: Optional obj created by the :meth:`create` method.
+    #: This is used by :class:`StartStopStep` to keep the
+    #: original service object.
+    obj = None
+
+    #: This flag is reserved for the workers Consumer,
+    #: since it is required to always be started last.
+    #: There can only be one object marked with last
+    #: in every namespace.
+    last = False
+
+    #: This provides the default for :meth:`include_if`.
+    enabled = True
+
+    def __init__(self, parent, **kwargs):
+        pass
+
+    def create(self, parent):
+        """Create the step."""
+        pass
+
+    def include_if(self, parent):
+        """An optional predicate that decided whether this
+        step should be created."""
+        return self.enabled
+
+    def instantiate(self, name, *args, **kwargs):
+        return instantiate(name, *args, **kwargs)
+
+    def include(self, parent):
+        if self.include_if(parent):
+            self.obj = self.create(parent)
+            return True
+
+    def __repr__(self):
+        return '<step: {0.alias}>'.format(self)
+
+    @property
+    def alias(self):
+        return self.name.rsplit('.', 1)[-1]
+
+
+class StartStopStep(Step):
+
+    def start(self, parent):
+        if self.obj:
+            return self.obj.start()
+
+    def stop(self, parent):
+        if self.obj:
+            return self.obj.stop()
+
+    def close(self, parent):
+        pass
+
+    def terminate(self, parent):
+        self.stop(parent)
+
+    def include(self, parent):
+        if super(StartStopStep, self).include(parent):
+            parent.steps.append(self)
+
+
+class ConsumerStep(StartStopStep):
+    requires = ('Connection', )
+    consumers = None
+
+    def get_consumers(self, channel):
+        raise NotImplementedError('missing get_consumers')
+
+    def start(self, c):
+        self.consumers = self.get_consumers(c.connection)
+        for consumer in self.consumers or []:
+            consumer.consume()
+
+    def stop(self, c):
+        for consumer in self.consumers or []:
+            ignore_errors(c.connection, consumer.cancel)
+
+    def shutdown(self, c):
+        self.stop(c)
+        for consumer in self.consumers or []:
+            if consumer.channel:
+                ignore_errors(c.connection, consumer.channel.close)

+ 3 - 8
celery/tests/bin/test_celeryd.py

@@ -60,10 +60,7 @@ def disable_stdouts(fun):
 
 
 
 
 class Worker(cd.Worker):
 class Worker(cd.Worker):
-
-    def __init__(self, *args, **kwargs):
-        super(Worker, self).__init__(*args, **kwargs)
-        self.redirect_stdouts = False
+    redirect_stdouts = False
 
 
     def start(self, *args, **kwargs):
     def start(self, *args, **kwargs):
         self.on_start()
         self.on_start()
@@ -292,9 +289,7 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_redirect_stdouts(self):
     def test_redirect_stdouts(self):
-        worker = self.Worker()
-        worker.redirect_stdouts = False
-        worker.setup_logging()
+        self.Worker(redirect_stdouts=False)
         with self.assertRaises(AttributeError):
         with self.assertRaises(AttributeError):
             sys.stdout.logger
             sys.stdout.logger
 
 
@@ -306,7 +301,7 @@ class test_Worker(WorkerAppCase):
             logging_setup[0] = True
             logging_setup[0] = True
 
 
         try:
         try:
-            worker = self.Worker()
+            worker = self.Worker(redirect_stdouts=False)
             worker.app.log.__class__._setup = False
             worker.app.log.__class__._setup = False
             worker.setup_logging()
             worker.setup_logging()
             self.assertTrue(logging_setup[0])
             self.assertTrue(logging_setup[0])

+ 43 - 84
celery/tests/worker/test_bootsteps.py

@@ -2,37 +2,26 @@ from __future__ import absolute_import
 
 
 from mock import Mock
 from mock import Mock
 
 
-from celery.worker import bootsteps
+from celery import bootsteps
 
 
 from celery.tests.utils import AppCase, Case
 from celery.tests.utils import AppCase, Case
 
 
 
 
-class test_Component(Case):
+class test_Step(Case):
 
 
-    class Def(bootsteps.Component):
-        name = 'test_Component.Def'
-
-    def test_components_must_be_named(self):
-        with self.assertRaises(NotImplementedError):
-
-            class X(bootsteps.Component):
-                pass
-
-        class Y(bootsteps.Component):
-            abstract = True
+    class Def(bootsteps.Step):
+        name = 'test_Step.Def'
 
 
     def test_namespace_name(self, ns='test_namespace_name'):
     def test_namespace_name(self, ns='test_namespace_name'):
 
 
-        class X(bootsteps.Component):
+        class X(bootsteps.Step):
             namespace = ns
             namespace = ns
             name = 'X'
             name = 'X'
-        self.assertEqual(X.namespace, ns)
         self.assertEqual(X.name, 'X')
         self.assertEqual(X.name, 'X')
 
 
-        class Y(bootsteps.Component):
-            name = '%s.Y' % (ns, )
-        self.assertEqual(Y.namespace, ns)
-        self.assertEqual(Y.name, 'Y')
+        class Y(bootsteps.Step):
+            name = '%s.Y' % ns
+        self.assertEqual(Y.name, '%s.Y' % ns)
 
 
     def test_init(self):
     def test_init(self):
         self.assertTrue(self.Def(self))
         self.assertTrue(self.Def(self))
@@ -70,13 +59,13 @@ class test_Component(Case):
         self.assertFalse(x.create.call_count)
         self.assertFalse(x.create.call_count)
 
 
 
 
-class test_StartStopComponent(Case):
+class test_StartStopStep(Case):
 
 
-    class Def(bootsteps.StartStopComponent):
-        name = 'test_StartStopComponent.Def'
+    class Def(bootsteps.StartStopStep):
+        name = 'test_StartStopStep.Def'
 
 
     def setUp(self):
     def setUp(self):
-        self.components = []
+        self.steps = []
 
 
     def test_start__stop(self):
     def test_start__stop(self):
         x = self.Def(self)
         x = self.Def(self)
@@ -84,10 +73,10 @@ class test_StartStopComponent(Case):
 
 
         # include creates the underlying object and sets
         # include creates the underlying object and sets
         # its x.obj attribute to it, as well as appending
         # its x.obj attribute to it, as well as appending
-        # it to the parent.components list.
+        # it to the parent.steps list.
         x.include(self)
         x.include(self)
-        self.assertTrue(self.components)
-        self.assertIs(self.components[0], x)
+        self.assertTrue(self.steps)
+        self.assertIs(self.steps[0], x)
 
 
         x.start(self)
         x.start(self)
         x.obj.start.assert_called_with()
         x.obj.start.assert_called_with()
@@ -99,7 +88,7 @@ class test_StartStopComponent(Case):
         x = self.Def(self)
         x = self.Def(self)
         x.enabled = False
         x.enabled = False
         x.include(self)
         x.include(self)
-        self.assertFalse(self.components)
+        self.assertFalse(self.steps)
 
 
     def test_terminate(self):
     def test_terminate(self):
         x = self.Def(self)
         x = self.Def(self)
@@ -116,47 +105,29 @@ class test_Namespace(AppCase):
     class NS(bootsteps.Namespace):
     class NS(bootsteps.Namespace):
         name = 'test_Namespace'
         name = 'test_Namespace'
 
 
-    class ImportingNS(bootsteps.Namespace):
-
-        def __init__(self, *args, **kwargs):
-            bootsteps.Namespace.__init__(self, *args, **kwargs)
-            self.imported = []
-
-        def modules(self):
-            return ['A', 'B', 'C']
+    def test_steps_added_to_unclaimed(self):
 
 
-        def import_module(self, module):
-            self.imported.append(module)
-
-    def test_components_added_to_unclaimed(self):
-
-        class tnA(bootsteps.Component):
+        class tnA(bootsteps.Step):
             name = 'test_Namespace.A'
             name = 'test_Namespace.A'
 
 
-        class tnB(bootsteps.Component):
+        class tnB(bootsteps.Step):
             name = 'test_Namespace.B'
             name = 'test_Namespace.B'
 
 
-        class xxA(bootsteps.Component):
+        class xxA(bootsteps.Step):
             name = 'xx.A'
             name = 'xx.A'
 
 
-        self.assertIn('A', self.NS._unclaimed['test_Namespace'])
-        self.assertIn('B', self.NS._unclaimed['test_Namespace'])
-        self.assertIn('A', self.NS._unclaimed['xx'])
-        self.assertNotIn('B', self.NS._unclaimed['xx'])
+        class NS(self.NS):
+            default_steps = [tnA, tnB]
+        ns = NS(app=self.app)
+
+        self.assertIn(tnA, ns._all_steps())
+        self.assertIn(tnB, ns._all_steps())
+        self.assertNotIn(xxA, ns._all_steps())
 
 
     def test_init(self):
     def test_init(self):
         ns = self.NS(app=self.app)
         ns = self.NS(app=self.app)
         self.assertIs(ns.app, self.app)
         self.assertIs(ns.app, self.app)
         self.assertEqual(ns.name, 'test_Namespace')
         self.assertEqual(ns.name, 'test_Namespace')
-        self.assertFalse(ns.services)
-
-    def test_interface_modules(self):
-        self.NS(app=self.app).modules()
-
-    def test_load_modules(self):
-        x = self.ImportingNS(app=self.app)
-        x.load_modules()
-        self.assertListEqual(x.imported, ['A', 'B', 'C'])
 
 
     def test_apply(self):
     def test_apply(self):
 
 
@@ -166,44 +137,32 @@ class test_Namespace(AppCase):
             def modules(self):
             def modules(self):
                 return ['A', 'B']
                 return ['A', 'B']
 
 
-        class A(bootsteps.Component):
-            name = 'test_apply.A'
-            requires = ['C']
-
-        class B(bootsteps.Component):
+        class B(bootsteps.Step):
             name = 'test_apply.B'
             name = 'test_apply.B'
 
 
-        class C(bootsteps.Component):
+        class C(bootsteps.Step):
             name = 'test_apply.C'
             name = 'test_apply.C'
-            requires = ['B']
+            requires = [B]
 
 
-        class D(bootsteps.Component):
+        class A(bootsteps.Step):
+            name = 'test_apply.A'
+            requires = [C]
+
+        class D(bootsteps.Step):
             name = 'test_apply.D'
             name = 'test_apply.D'
             last = True
             last = True
 
 
-        x = MyNS(app=self.app)
-        x.import_module = Mock()
+        x = MyNS([A, D], app=self.app)
         x.apply(self)
         x.apply(self)
 
 
-        self.assertItemsEqual(x.components.values(), [A, B, C, D])
-        self.assertTrue(x.import_module.call_count)
-
-        for boot_step in x.boot_steps:
-            self.assertEqual(boot_step.namespace, x)
-
-        self.assertIsInstance(x.boot_steps[0], B)
-        self.assertIsInstance(x.boot_steps[1], C)
-        self.assertIsInstance(x.boot_steps[2], A)
-        self.assertIsInstance(x.boot_steps[3], D)
-
-        self.assertIs(x['A'], A)
-
-    def test_import_module(self):
-        x = self.NS(app=self.app)
-        import os
-        self.assertIs(x.import_module('os'), os)
+        self.assertIsInstance(x.order[0], B)
+        self.assertIsInstance(x.order[1], C)
+        self.assertIsInstance(x.order[2], A)
+        self.assertIsInstance(x.order[3], D)
+        self.assertIn(A, x.types)
+        self.assertIs(x[A.name], x.order[2])
 
 
-    def test_find_last_but_no_components(self):
+    def test_find_last_but_no_steps(self):
 
 
         class MyNS(bootsteps.Namespace):
         class MyNS(bootsteps.Namespace):
             name = 'qwejwioqjewoqiej'
             name = 'qwejwioqjewoqiej'

+ 207 - 144
celery/tests/worker/test_worker.py

@@ -9,7 +9,7 @@ from Queue import Empty
 
 
 from billiard.exceptions import WorkerLostError
 from billiard.exceptions import WorkerLostError
 from kombu import Connection
 from kombu import Connection
-from kombu.common import QoS, PREFETCH_COUNT_MAX
+from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
 from kombu.exceptions import StdChannelError
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from kombu.transport.base import Message
 from mock import Mock, patch
 from mock import Mock, patch
@@ -17,6 +17,7 @@ from nose import SkipTest
 
 
 from celery import current_app
 from celery import current_app
 from celery.app.defaults import DEFAULTS
 from celery.app.defaults import DEFAULTS
+from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
 from celery.exceptions import SystemTerminate
 from celery.exceptions import SystemTerminate
@@ -24,33 +25,51 @@ from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker import WorkController
-from celery.worker.components import Queues, Timers, EvLoop, Pool
-from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
+from celery.worker.components import Queues, Timers, Hub, Pool
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker.job import Request
-from celery.worker.consumer import BlockingConsumer
+from celery.worker import consumer
+from celery.worker.consumer import Consumer
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
 from celery.tests.utils import AppCase, Case
 from celery.tests.utils import AppCase, Case
 
 
 
 
+def MockStep(step=None):
+    step = Mock() if step is None else step
+    step.namespace = Mock()
+    step.namespace.name = 'MockNS'
+    step.name = 'MockStep'
+    return step
+
+
 class PlaceHolder(object):
 class PlaceHolder(object):
         pass
         pass
 
 
 
 
-class MyKombuConsumer(BlockingConsumer):
+def find_step(obj, typ):
+    return obj.namespace.steps[typ.name]
+
+
+class _MyKombuConsumer(Consumer):
     broadcast_consumer = Mock()
     broadcast_consumer = Mock()
     task_consumer = Mock()
     task_consumer = Mock()
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         kwargs.setdefault('pool', BasePool(2))
         kwargs.setdefault('pool', BasePool(2))
-        super(MyKombuConsumer, self).__init__(*args, **kwargs)
+        super(_MyKombuConsumer, self).__init__(*args, **kwargs)
 
 
     def restart_heartbeat(self):
     def restart_heartbeat(self):
         self.heart = None
         self.heart = None
 
 
 
 
+class MyKombuConsumer(Consumer):
+
+    def loop(self, *args, **kwargs):
+        pass
+
+
 class MockNode(object):
 class MockNode(object):
     commands = []
     commands = []
 
 
@@ -227,90 +246,102 @@ class test_Consumer(Case):
 
 
     def test_start_when_closed(self):
     def test_start_when_closed(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.start()
         l.start()
 
 
     def test_connection(self):
     def test_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
 
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
         self.assertIsInstance(l.connection, Connection)
 
 
-        l._state = RUN
+        l.namespace.state = RUN
         l.event_dispatcher = None
         l.event_dispatcher = None
-        l.stop_consumers(close_connection=False)
+        l.namespace.restart(l)
         self.assertTrue(l.connection)
         self.assertTrue(l.connection)
 
 
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
         self.assertIsNone(l.task_consumer)
 
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
         self.assertIsInstance(l.connection, Connection)
-        l.stop_consumers()
+        l.namespace.restart(l)
 
 
         l.stop()
         l.stop()
-        l.close_connection()
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
         self.assertIsNone(l.task_consumer)
 
 
     def test_close_connection(self):
     def test_close_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
-        l.close_connection()
+        l.namespace.state = RUN
+        step = find_step(l, consumer.Connection)
+        conn = l.connection = Mock()
+        step.shutdown(l)
+        self.assertTrue(conn.close.called)
+        self.assertIsNone(l.connection)
 
 
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         eventer = l.event_dispatcher = Mock()
         eventer = l.event_dispatcher = Mock()
         eventer.enabled = True
         eventer.enabled = True
         heart = l.heart = MockHeart()
         heart = l.heart = MockHeart()
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        Events = find_step(l, consumer.Events)
+        Events.shutdown(l)
+        Heart = find_step(l, consumer.Heart)
+        Heart.shutdown(l)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
         self.assertTrue(heart.closed)
 
 
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.warn')
     def test_receive_message_unknown(self, warn):
     def test_receive_message_unknown(self, warn):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
         m = create_message(backend, unknown={'baz': '!!!'})
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
+        l.node = MockNode()
 
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(warn.call_count)
         self.assertTrue(warn.call_count)
 
 
-    @patch('celery.utils.timer2.to_timestamp')
+    @patch('celery.worker.consumer.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
         to_timestamp.side_effect = OverflowError()
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                                    args=('2, 2'),
                                    args=('2, 2'),
                                    kwargs={},
                                    kwargs={},
                                    eta=datetime.now().isoformat())
                                    eta=datetime.now().isoformat())
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
+        l.node = MockNode()
         l.update_strategies()
         l.update_strategies()
 
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(m.acknowledged)
         self.assertTrue(m.acknowledged)
         self.assertTrue(to_timestamp.call_count)
         self.assertTrue(to_timestamp.call_count)
 
 
     @patch('celery.worker.consumer.error')
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
     def test_receive_message_InvalidTaskError(self, error):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.update_strategies()
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
 
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertIn('Received invalid task message', error.call_args[0][0])
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
 
     @patch('celery.worker.consumer.crit')
     @patch('celery.worker.consumer.crit')
     def test_on_decode_error(self, crit):
     def test_on_decode_error(self, crit):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
 
 
         class MockMessage(Mock):
         class MockMessage(Mock):
             content_type = 'application/x-msgpack'
             content_type = 'application/x-msgpack'
@@ -322,14 +353,25 @@ class test_Consumer(Case):
         self.assertTrue(message.ack.call_count)
         self.assertTrue(message.ack.call_count)
         self.assertIn("Can't decode message body", crit.call_args[0][0])
         self.assertIn("Can't decode message body", crit.call_args[0][0])
 
 
+    def _get_on_message(self, l):
+        l.qos = Mock()
+        l.event_dispatcher = Mock()
+        l.task_consumer = Mock()
+        l.connection = Mock()
+        l.connection.drain_events.side_effect = SystemExit()
+
+        with self.assertRaises(SystemExit):
+            l.loop(*l.loop_args())
+        self.assertTrue(l.task_consumer.register_callback.called)
+        return l.task_consumer.register_callback.call_args[0][0]
+
     def test_receieve_message(self):
     def test_receieve_message(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
         l.update_strategies()
-
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
 
 
         in_bucket = self.ready_queue.get_nowait()
         in_bucket = self.ready_queue.get_nowait()
         self.assertIsInstance(in_bucket, Request)
         self.assertIsInstance(in_bucket, Request)
@@ -339,10 +381,10 @@ class test_Consumer(Case):
 
 
     def test_start_connection_error(self):
     def test_start_connection_error(self):
 
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
             iterations = 0
 
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                 if not self.iterations:
                     self.iterations = 1
                     self.iterations = 1
                     raise KeyError('foo')
                     raise KeyError('foo')
@@ -360,10 +402,10 @@ class test_Consumer(Case):
         # Regression test for AMQPChannelExceptions that can occur within the
         # Regression test for AMQPChannelExceptions that can occur within the
         # consumer. (i.e. 404 errors)
         # consumer. (i.e. 404 errors)
 
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
             iterations = 0
 
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                 if not self.iterations:
                     self.iterations = 1
                     self.iterations = 1
                     raise KeyError('foo')
                     raise KeyError('foo')
@@ -377,7 +419,7 @@ class test_Consumer(Case):
         l.heart.stop()
         l.heart.stop()
         l.timer.stop()
         l.timer.stop()
 
 
-    def test_consume_messages_ignores_socket_timeout(self):
+    def test_loop_ignores_socket_timeout(self):
 
 
         class Connection(current_app.connection().__class__):
         class Connection(current_app.connection().__class__):
             obj = None
             obj = None
@@ -391,9 +433,9 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.connection.obj = l
         l.connection.obj = l
         l.qos = QoS(l.task_consumer, 10)
         l.qos = QoS(l.task_consumer, 10)
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
 
-    def test_consume_messages_when_socket_error(self):
+    def test_loop_when_socket_error(self):
 
 
         class Connection(current_app.connection().__class__):
         class Connection(current_app.connection().__class__):
             obj = None
             obj = None
@@ -402,20 +444,20 @@ class test_Consumer(Case):
                 self.obj.connection = None
                 self.obj.connection = None
                 raise socket.error('foo')
                 raise socket.error('foo')
 
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
+        l = Consumer(self.ready_queue, timer=self.timer)
+        l.namespace.state = RUN
         c = l.connection = Connection()
         c = l.connection = Connection()
         l.connection.obj = l
         l.connection.obj = l
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
         l.qos = QoS(l.task_consumer, 10)
         with self.assertRaises(socket.error):
         with self.assertRaises(socket.error):
-            l.consume_messages()
+            l.loop(*l.loop_args())
 
 
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.connection = c
         l.connection = c
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
 
-    def test_consume_messages(self):
+    def test_loop(self):
 
 
         class Connection(current_app.connection().__class__):
         class Connection(current_app.connection().__class__):
             obj = None
             obj = None
@@ -423,14 +465,14 @@ class test_Consumer(Case):
             def drain_events(self, **kwargs):
             def drain_events(self, **kwargs):
                 self.obj.connection = None
                 self.obj.connection = None
 
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         l.connection = Connection()
         l.connection = Connection()
         l.connection.obj = l
         l.connection.obj = l
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
         l.qos = QoS(l.task_consumer, 10)
 
 
-        l.consume_messages()
-        l.consume_messages()
+        l.loop(*l.loop_args())
+        l.loop(*l.loop_args())
         self.assertTrue(l.task_consumer.consume.call_count)
         self.assertTrue(l.task_consumer.consume.call_count)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
         l.task_consumer.qos = Mock()
         l.task_consumer.qos = Mock()
@@ -441,15 +483,15 @@ class test_Consumer(Case):
         self.assertEqual(l.qos.value, 9)
         self.assertEqual(l.qos.value, 9)
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
 
-    def test_maybe_conn_error(self):
+    def test_ignore_errors(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l.connection_errors = (KeyError, )
         l.connection_errors = (KeyError, )
         l.channel_errors = (SyntaxError, )
         l.channel_errors = (SyntaxError, )
-        l.maybe_conn_error(Mock(side_effect=AttributeError('foo')))
-        l.maybe_conn_error(Mock(side_effect=KeyError('foo')))
-        l.maybe_conn_error(Mock(side_effect=SyntaxError('foo')))
+        ignore_errors(l, Mock(side_effect=AttributeError('foo')))
+        ignore_errors(l, Mock(side_effect=KeyError('foo')))
+        ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
         with self.assertRaises(IndexError):
         with self.assertRaises(IndexError):
-            l.maybe_conn_error(Mock(side_effect=IndexError('foo')))
+            ignore_errors(l, Mock(side_effect=IndexError('foo')))
 
 
     def test_apply_eta_task(self):
     def test_apply_eta_task(self):
         from celery.worker import state
         from celery.worker import state
@@ -464,18 +506,20 @@ class test_Consumer(Case):
         self.assertIs(self.ready_queue.get_nowait(), task)
         self.assertIs(self.ready_queue.get_nowait(), task)
 
 
     def test_receieve_message_eta_isoformat(self):
     def test_receieve_message_eta_isoformat(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            eta=datetime.now().isoformat(),
                            eta=datetime.now().isoformat(),
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
 
 
         l.task_consumer = Mock()
         l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
+        l.qos = QoS(l.task_consumer, 1)
         current_pcount = l.qos.value
         current_pcount = l.qos.value
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.enabled = False
         l.enabled = False
         l.update_strategies()
         l.update_strategies()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         l.timer.stop()
         l.timer.join(1)
         l.timer.join(1)
 
 
@@ -488,28 +532,30 @@ class test_Consumer(Case):
         self.assertGreater(l.qos.value, current_pcount)
         self.assertGreater(l.qos.value, current_pcount)
         l.timer.stop()
         l.timer.stop()
 
 
-    def test_on_control(self):
+    def test_pidbox_callback(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        l.reset_pidbox_node = Mock()
+        con = find_step(l, consumer.Control).box
+        con.node = Mock()
+        con.reset = Mock()
 
 
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
 
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = KeyError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.node = Mock()
+        con.node.handle_message.side_effect = KeyError('foo')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
 
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = ValueError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
-        l.reset_pidbox_node.assert_called_with()
+        con.node = Mock()
+        con.node.handle_message.side_effect = ValueError('foo')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
+        self.assertTrue(con.reset.called)
 
 
     def test_revoke(self):
     def test_revoke(self):
         ready_queue = FastQueue()
         ready_queue = FastQueue()
-        l = MyKombuConsumer(ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         backend = Mock()
         id = uuid()
         id = uuid()
         t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
         t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
@@ -517,16 +563,19 @@ class test_Consumer(Case):
         from celery.worker.state import revoked
         from celery.worker.state import revoked
         revoked.add(id)
         revoked.add(id)
 
 
-        l.receive_message(t.decode(), t)
+        callback = self._get_on_message(l)
+        callback(t.decode(), t)
         self.assertTrue(ready_queue.empty())
         self.assertTrue(ready_queue.empty())
 
 
     def test_receieve_message_not_registered(self):
     def test_receieve_message_not_registered(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
 
 
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         with self.assertRaises(Empty):
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
             self.ready_queue.get_nowait()
         self.assertTrue(self.timer.empty())
         self.assertTrue(self.timer.empty())
@@ -534,7 +583,7 @@ class test_Consumer(Case):
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.logger')
     @patch('celery.worker.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
     def test_receieve_message_ack_raises(self, logger, warn):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         backend = Mock()
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
 
@@ -542,7 +591,8 @@ class test_Consumer(Case):
         l.connection_errors = (socket.error, )
         l.connection_errors = (socket.error, )
         m.reject = Mock()
         m.reject = Mock()
         m.reject.side_effect = socket.error('foo')
         m.reject.side_effect = socket.error('foo')
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         self.assertTrue(warn.call_count)
         self.assertTrue(warn.call_count)
         with self.assertRaises(Empty):
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
             self.ready_queue.get_nowait()
@@ -551,7 +601,8 @@ class test_Consumer(Case):
         self.assertTrue(logger.critical.call_count)
         self.assertTrue(logger.critical.call_count)
 
 
     def test_receieve_message_eta(self):
     def test_receieve_message_eta(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.event_dispatcher._outbound_buffer = deque()
         l.event_dispatcher._outbound_buffer = deque()
         backend = Mock()
         backend = Mock()
@@ -560,16 +611,17 @@ class test_Consumer(Case):
                            eta=(datetime.now() +
                            eta=(datetime.now() +
                                timedelta(days=1)).isoformat())
                                timedelta(days=1)).isoformat())
 
 
-        l.reset_connection()
+        l.namespace.start(l)
         p = l.app.conf.BROKER_CONNECTION_RETRY
         p = l.app.conf.BROKER_CONNECTION_RETRY
         l.app.conf.BROKER_CONNECTION_RETRY = False
         l.app.conf.BROKER_CONNECTION_RETRY = False
         try:
         try:
-            l.reset_connection()
+            l.namespace.start(l)
         finally:
         finally:
             l.app.conf.BROKER_CONNECTION_RETRY = p
             l.app.conf.BROKER_CONNECTION_RETRY = p
-        l.stop_consumers()
+        l.namespace.restart(l)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         l.timer.stop()
         in_hold = l.timer.queue[0]
         in_hold = l.timer.queue[0]
         self.assertEqual(len(in_hold), 3)
         self.assertEqual(len(in_hold), 3)
@@ -583,24 +635,34 @@ class test_Consumer(Case):
 
 
     def test_reset_pidbox_node(self):
     def test_reset_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        chan = l.pidbox_node.channel = Mock()
+        con = find_step(l, consumer.Control).box
+        con.node = Mock()
+        chan = con.node.channel = Mock()
         l.connection = Mock()
         l.connection = Mock()
         chan.close.side_effect = socket.error('foo')
         chan.close.side_effect = socket.error('foo')
         l.connection_errors = (socket.error, )
         l.connection_errors = (socket.error, )
-        l.reset_pidbox_node()
+        con.reset()
         chan.close.assert_called_with()
         chan.close.assert_called_with()
 
 
     def test_reset_pidbox_node_green(self):
     def test_reset_pidbox_node_green(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pool = Mock()
-        l.pool.is_green = True
-        l.reset_pidbox_node()
-        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+        from celery.worker.pidbox import gPidbox
+        pool = Mock()
+        pool.is_green = True
+        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        con = find_step(l, consumer.Control)
+        self.assertIsInstance(con.box, gPidbox)
+        con.start(l)
+        l.pool.spawn_n.assert_called_with(
+            con.box.loop, l,
+        )
 
 
     def test__green_pidbox_node(self):
     def test__green_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
+        pool = Mock()
+        pool.is_green = True
+        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        l.node = Mock()
+        controller = find_step(l, consumer.Control)
+        box = controller.box
 
 
         class BConsumer(Mock):
         class BConsumer(Mock):
 
 
@@ -611,7 +673,7 @@ class test_Consumer(Case):
             def __exit__(self, *exc_info):
             def __exit__(self, *exc_info):
                 self.cancel()
                 self.cancel()
 
 
-        l.pidbox_node.listen = BConsumer()
+        controller.box.node.listen = BConsumer()
         connections = []
         connections = []
 
 
         class Connection(object):
         class Connection(object):
@@ -640,25 +702,26 @@ class test_Consumer(Case):
                     self.calls += 1
                     self.calls += 1
                     raise socket.timeout()
                     raise socket.timeout()
                 self.obj.connection = None
                 self.obj.connection = None
-                self.obj._pidbox_node_shutdown.set()
+                controller.box._node_shutdown.set()
 
 
             def close(self):
             def close(self):
                 self.closed = True
                 self.closed = True
 
 
         l.connection = Mock()
         l.connection = Mock()
-        l._open_connection = lambda: Connection(obj=l)
-        l._green_pidbox_node()
+        l.connect = lambda: Connection(obj=l)
+        controller = find_step(l, consumer.Control)
+        controller.box.loop(l)
 
 
-        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
-        self.assertTrue(l.broadcast_consumer)
-        l.broadcast_consumer.consume.assert_called_with()
+        self.assertTrue(controller.box.node.listen.called)
+        self.assertTrue(controller.box.consumer)
+        controller.box.consumer.consume.assert_called_with()
 
 
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
         self.assertTrue(connections[0].closed)
         self.assertTrue(connections[0].closed)
 
 
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.utils.sleep')
     @patch('kombu.utils.sleep')
-    def test_open_connection_errback(self, sleep, connect):
+    def test_connect_errback(self, sleep, connect):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         from kombu.transport.memory import Transport
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
         Transport.connection_errors = (StdChannelError, )
@@ -668,17 +731,18 @@ class test_Consumer(Case):
                 return
                 return
             raise StdChannelError()
             raise StdChannelError()
         connect.side_effect = effect
         connect.side_effect = effect
-        l._open_connection()
+        l.connect()
         connect.assert_called_with()
         connect.assert_called_with()
 
 
     def test_stop_pidbox_node(self):
     def test_stop_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._pidbox_node_stopped = Event()
-        l._pidbox_node_shutdown = Event()
-        l._pidbox_node_stopped.set()
-        l.stop_pidbox_node()
+        cont = find_step(l, consumer.Control)
+        cont._node_stopped = Event()
+        cont._node_shutdown = Event()
+        cont._node_stopped.set()
+        cont.stop(l)
 
 
-    def test_start__consume_messages(self):
+    def test_start__loop(self):
 
 
         class _QoS(object):
         class _QoS(object):
             prev = 3
             prev = 3
@@ -703,18 +767,17 @@ class test_Consumer(Case):
         l.connection = Connection()
         l.connection = Connection()
         l.iterations = 0
         l.iterations = 0
 
 
-        def raises_KeyError(limit=None):
+        def raises_KeyError(*args, **kwargs):
             l.iterations += 1
             l.iterations += 1
             if l.qos.prev != l.qos.value:
             if l.qos.prev != l.qos.value:
                 l.qos.update()
                 l.qos.update()
             if l.iterations >= 2:
             if l.iterations >= 2:
                 raise KeyError('foo')
                 raise KeyError('foo')
 
 
-        l.consume_messages = raises_KeyError
+        l.loop = raises_KeyError
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             l.start()
             l.start()
-        self.assertTrue(init_callback.call_count)
-        self.assertEqual(l.iterations, 1)
+        self.assertEqual(l.iterations, 2)
         self.assertEqual(l.qos.prev, l.qos.value)
         self.assertEqual(l.qos.prev, l.qos.value)
 
 
         init_callback.reset_mock()
         init_callback.reset_mock()
@@ -724,25 +787,25 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.connection = Connection()
         l.connection = Connection()
-        l.consume_messages = Mock(side_effect=socket.error('foo'))
+        l.loop = Mock(side_effect=socket.error('foo'))
         with self.assertRaises(socket.error):
         with self.assertRaises(socket.error):
             l.start()
             l.start()
-        self.assertTrue(init_callback.call_count)
-        self.assertTrue(l.consume_messages.call_count)
+        self.assertTrue(l.loop.call_count)
 
 
     def test_reset_connection_with_no_node(self):
     def test_reset_connection_with_no_node(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         self.assertEqual(None, l.pool)
         self.assertEqual(None, l.pool)
-        l.reset_connection()
+        l.namespace.start(l)
 
 
     def test_on_task_revoked(self):
     def test_on_task_revoked(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task = Mock()
         task.revoked.return_value = True
         task.revoked.return_value = True
         l.on_task(task)
         l.on_task(task)
 
 
     def test_on_task_no_events(self):
     def test_on_task_no_events(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task = Mock()
         task.revoked.return_value = False
         task.revoked.return_value = False
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
@@ -778,7 +841,7 @@ class test_WorkController(AppCase):
     def test_use_pidfile(self, create_pidlock):
     def test_use_pidfile(self, create_pidlock):
         create_pidlock.return_value = Mock()
         create_pidlock.return_value = Mock()
         worker = self.create_worker(pidfile='pidfilelockfilepid')
         worker = self.create_worker(pidfile='pidfilelockfilepid')
-        worker.components = []
+        worker.steps = []
         worker.start()
         worker.start()
         self.assertTrue(create_pidlock.called)
         self.assertTrue(create_pidlock.called)
         worker.stop()
         worker.stop()
@@ -825,12 +888,12 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.pool)
         self.assertTrue(worker.pool)
         self.assertTrue(worker.consumer)
         self.assertTrue(worker.consumer)
         self.assertTrue(worker.mediator)
         self.assertTrue(worker.mediator)
-        self.assertTrue(worker.components)
+        self.assertTrue(worker.steps)
 
 
     def test_with_embedded_celerybeat(self):
     def test_with_embedded_celerybeat(self):
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         self.assertTrue(worker.beat)
         self.assertTrue(worker.beat)
-        self.assertIn(worker.beat, [w.obj for w in worker.components])
+        self.assertIn(worker.beat, [w.obj for w in worker.steps])
 
 
     def test_with_autoscaler(self):
     def test_with_autoscaler(self):
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
@@ -892,7 +955,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode())
         task = Request.from_message(m, m.decode())
-        worker.components = []
+        worker.steps = []
         worker.namespace.state = RUN
         worker.namespace.state = RUN
         with self.assertRaises(KeyboardInterrupt):
         with self.assertRaises(KeyboardInterrupt):
             worker.process_task(task)
             worker.process_task(task)
@@ -906,7 +969,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode())
         task = Request.from_message(m, m.decode())
-        worker.components = []
+        worker.steps = []
         worker.namespace.state = RUN
         worker.namespace.state = RUN
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             worker.process_task(task)
             worker.process_task(task)
@@ -925,17 +988,18 @@ class test_WorkController(AppCase):
 
 
     def test_start_catches_base_exceptions(self):
     def test_start_catches_base_exceptions(self):
         worker1 = self.create_worker()
         worker1 = self.create_worker()
-        stc = Mock()
+        stc = MockStep()
         stc.start.side_effect = SystemTerminate()
         stc.start.side_effect = SystemTerminate()
-        worker1.components = [stc]
+        worker1.steps = [stc]
         worker1.start()
         worker1.start()
+        stc.start.assert_called_with(worker1)
         self.assertTrue(stc.terminate.call_count)
         self.assertTrue(stc.terminate.call_count)
 
 
         worker2 = self.create_worker()
         worker2 = self.create_worker()
-        sec = Mock()
+        sec = MockStep()
         sec.start.side_effect = SystemExit()
         sec.start.side_effect = SystemExit()
         sec.terminate = None
         sec.terminate = None
-        worker2.components = [sec]
+        worker2.steps = [sec]
         worker2.start()
         worker2.start()
         self.assertTrue(sec.stop.call_count)
         self.assertTrue(sec.stop.call_count)
 
 
@@ -988,18 +1052,18 @@ class test_WorkController(AppCase):
     def test_start__stop(self):
     def test_start__stop(self):
         worker = self.worker
         worker = self.worker
         worker.namespace.shutdown_complete.set()
         worker.namespace.shutdown_complete.set()
-        worker.components = [StartStopComponent(self) for _ in range(4)]
+        worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
         worker.namespace.state = RUN
         worker.namespace.state = RUN
         worker.namespace.started = 4
         worker.namespace.started = 4
-        for w in worker.components:
+        for w in worker.steps:
             w.start = Mock()
             w.start = Mock()
             w.stop = Mock()
             w.stop = Mock()
 
 
         worker.start()
         worker.start()
-        for w in worker.components:
+        for w in worker.steps:
             self.assertTrue(w.start.call_count)
             self.assertTrue(w.start.call_count)
         worker.stop()
         worker.stop()
-        for w in worker.components:
+        for w in worker.steps:
             self.assertTrue(w.stop.call_count)
             self.assertTrue(w.stop.call_count)
 
 
         # Doesn't close pool if no pool.
         # Doesn't close pool if no pool.
@@ -1008,15 +1072,15 @@ class test_WorkController(AppCase):
         worker.stop()
         worker.stop()
 
 
         # test that stop of None is not attempted
         # test that stop of None is not attempted
-        worker.components[-1] = None
+        worker.steps[-1] = None
         worker.start()
         worker.start()
         worker.stop()
         worker.stop()
 
 
-    def test_component_raises(self):
+    def test_step_raises(self):
         worker = self.worker
         worker = self.worker
-        comp = Mock()
-        worker.components = [comp]
-        comp.start.side_effect = TypeError()
+        step = Mock()
+        worker.steps = [step]
+        step.start.side_effect = TypeError()
         worker.stop = Mock()
         worker.stop = Mock()
         worker.start()
         worker.start()
         worker.stop.assert_called_with()
         worker.stop.assert_called_with()
@@ -1029,16 +1093,15 @@ class test_WorkController(AppCase):
         worker.namespace.shutdown_complete.set()
         worker.namespace.shutdown_complete.set()
         worker.namespace.started = 5
         worker.namespace.started = 5
         worker.namespace.state = RUN
         worker.namespace.state = RUN
-        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
-
+        worker.steps = [MockStep() for _ in range(5)]
         worker.start()
         worker.start()
-        for w in worker.components[:3]:
+        for w in worker.steps[:3]:
             self.assertTrue(w.start.call_count)
             self.assertTrue(w.start.call_count)
-        self.assertTrue(worker.namespace.started, len(worker.components))
+        self.assertEqual(worker.namespace.started, len(worker.steps))
         self.assertEqual(worker.namespace.state, RUN)
         self.assertEqual(worker.namespace.state, RUN)
         worker.terminate()
         worker.terminate()
-        for component in worker.components:
-            self.assertTrue(component.terminate.call_count)
+        for step in worker.steps:
+            self.assertTrue(step.terminate.call_count)
 
 
     def test_Queues_pool_not_rlimit_safe(self):
     def test_Queues_pool_not_rlimit_safe(self):
         w = Mock()
         w = Mock()
@@ -1052,9 +1115,9 @@ class test_WorkController(AppCase):
         Queues(w).create(w)
         Queues(w).create(w)
         self.assertIs(w.ready_queue.put, w.process_task)
         self.assertIs(w.ready_queue.put, w.process_task)
 
 
-    def test_EvLoop_crate(self):
+    def test_Hub_create(self):
         w = Mock()
         w = Mock()
-        x = EvLoop(w)
+        x = Hub(w)
         hub = x.create(w)
         hub = x.create(w)
         self.assertTrue(w.timer.max_interval)
         self.assertTrue(w.timer.max_interval)
         self.assertIs(w.hub, hub)
         self.assertIs(w.hub, hub)

+ 4 - 12
celery/utils/imports.py

@@ -24,18 +24,10 @@ class NotAPackage(Exception):
     pass
     pass
 
 
 
 
-if sys.version_info >= (3, 3):  # pragma: no cover
-
-    def qualname(obj):
-        return obj.__qualname__
-
-else:
-
-    def qualname(obj):  # noqa
-        if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
-            return qualname(obj.__class__)
-
-        return '%s.%s' % (obj.__module__, obj.__name__)
+def qualname(obj):
+    if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
+        obj = obj.__class__
+    return '%s.%s' % (obj.__module__, obj.__name__)
 
 
 
 
 def instantiate(name, *args, **kwargs):
 def instantiate(name, *args, **kwargs):

+ 2 - 0
celery/utils/text.py

@@ -13,6 +13,8 @@ from textwrap import fill
 
 
 from pprint import pformat
 from pprint import pformat
 
 
+from kombu.utils.encoding import safe_repr
+
 
 
 def dedent_initial(s, n=4):
 def dedent_initial(s, n=4):
     return s[n:] if s[:n] == ' ' * n else s
     return s[n:] if s[:n] == ' ' * n else s

+ 11 - 0
celery/utils/threads.py

@@ -9,10 +9,13 @@
 from __future__ import absolute_import, print_function
 from __future__ import absolute_import, print_function
 
 
 import os
 import os
+import socket
 import sys
 import sys
 import threading
 import threading
 import traceback
 import traceback
 
 
+from contextlib import contextmanager
+
 from celery.local import Proxy
 from celery.local import Proxy
 from celery.utils.compat import THREAD_TIMEOUT_MAX
 from celery.utils.compat import THREAD_TIMEOUT_MAX
 
 
@@ -284,6 +287,14 @@ class LocalManager(object):
             self.__class__.__name__, len(self.locals))
             self.__class__.__name__, len(self.locals))
 
 
 
 
+@contextmanager
+def default_socket_timeout(timeout):
+    prev = socket.getdefaulttimeout()
+    socket.setdefaulttimeout(timeout)
+    try:
+        yield
+    finally:
+        socket.setdefaulttimeout(prev)
+
+
 class _FastLocalStack(threading.local):
 class _FastLocalStack(threading.local):
 
 
     def __init__(self):
     def __init__(self):

+ 35 - 26
celery/worker/__init__.py

@@ -6,7 +6,7 @@
     :class:`WorkController` can be used to instantiate in-process workers.
     :class:`WorkController` can be used to instantiate in-process workers.
 
 
     The worker consists of several components, all managed by boot-steps
     The worker consists of several components, all managed by boot-steps
-    (mod:`celery.worker.bootsteps`).
+    (:mod:`celery.bootsteps`).
 
 
 """
 """
 from __future__ import absolute_import
 from __future__ import absolute_import
@@ -19,6 +19,7 @@ from billiard import cpu_count
 from kombu.syn import detect_environment
 from kombu.syn import detect_environment
 from kombu.utils.finalize import Finalize
 from kombu.utils.finalize import Finalize
 
 
+from celery import bootsteps
 from celery import concurrency as _concurrency
 from celery import concurrency as _concurrency
 from celery import platforms
 from celery import platforms
 from celery import signals
 from celery import signals
@@ -31,7 +32,6 @@ from celery.utils import worker_direct
 from celery.utils.imports import reload_from_cwd
 from celery.utils.imports import reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 from celery.utils.log import mlevel, worker_logger as logger
 
 
-from . import bootsteps
 from . import state
 from . import state
 
 
 UNKNOWN_QUEUE = """\
 UNKNOWN_QUEUE = """\
@@ -43,24 +43,6 @@ enable the CELERY_CREATE_MISSING_QUEUES setting.
 """
 """
 
 
 
 
-class Namespace(bootsteps.Namespace):
-    """This is the boot-step namespace of the :class:`WorkController`.
-
-    It loads modules from :setting:`CELERYD_BOOT_STEPS`, and its
-    own set of built-in boot-step modules.
-
-    """
-    name = 'worker'
-    builtin_boot_steps = ('celery.worker.components',
-                          'celery.worker.autoscale',
-                          'celery.worker.autoreload',
-                          'celery.worker.consumer',
-                          'celery.worker.mediator')
-
-    def modules(self):
-        return self.builtin_boot_steps + self.app.conf.CELERYD_BOOT_STEPS
-
-
 class WorkController(configurated):
 class WorkController(configurated):
     """Unmanaged worker instance."""
     """Unmanaged worker instance."""
     app = None
     app = None
@@ -90,6 +72,28 @@ class WorkController(configurated):
 
 
     pidlock = None
     pidlock = None
 
 
+    class Namespace(bootsteps.Namespace):
+        """This is the boot-step namespace of the :class:`WorkController`.
+
+        It loads modules from :setting:`CELERYD_BOOT_STEPS`, and its
+        own set of built-in boot-step modules.
+
+        """
+        name = 'Worker'
+        default_steps = set([
+            'celery.worker.components:Hub',
+            'celery.worker.components:Queues',
+            'celery.worker.components:Pool',
+            'celery.worker.components:Beat',
+            'celery.worker.components:Timers',
+            'celery.worker.components:StateDB',
+            'celery.worker.components:Consumer',
+            'celery.worker.autoscale:WorkerComponent',
+            'celery.worker.autoreload:WorkerComponent',
+            'celery.worker.mediator:WorkerComponent',
+
+        ])
+
     def __init__(self, app=None, hostname=None, **kwargs):
     def __init__(self, app=None, hostname=None, **kwargs):
         self.app = app_or_default(app or self.app)
         self.app = app_or_default(app or self.app)
         self.hostname = hostname or socket.gethostname()
         self.hostname = hostname or socket.gethostname()
@@ -117,18 +121,23 @@ class WorkController(configurated):
         self.loglevel = mlevel(self.loglevel)
         self.loglevel = mlevel(self.loglevel)
         self.ready_callback = ready_callback or self.on_consumer_ready
         self.ready_callback = ready_callback or self.on_consumer_ready
         self.use_eventloop = self.should_use_eventloop()
         self.use_eventloop = self.should_use_eventloop()
+        self.options = kwargs
 
 
         signals.worker_init.send(sender=self)
         signals.worker_init.send(sender=self)
 
 
         # Initialize boot steps
         # Initialize boot steps
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
-        self.components = []
-        self.namespace = Namespace(app=self.app,
-                                   on_start=self.on_start,
-                                   on_close=self.on_close,
-                                   on_stopped=self.on_stopped)
+        self.steps = []
+        self.on_init_namespace()
+        self.namespace = self.Namespace(app=self.app,
+                                        on_start=self.on_start,
+                                        on_close=self.on_close,
+                                        on_stopped=self.on_stopped)
         self.namespace.apply(self, **kwargs)
         self.namespace.apply(self, **kwargs)
 
 
+    def on_init_namespace(self):
+        pass
+
     def on_before_init(self, **kwargs):
     def on_before_init(self, **kwargs):
         pass
         pass
 
 
@@ -144,7 +153,7 @@ class WorkController(configurated):
 
 
     def on_stopped(self):
     def on_stopped(self):
         self.timer.stop()
         self.timer.stop()
-        self.consumer.close_connection()
+        self.consumer.shutdown()
 
 
         if self.pidlock:
         if self.pidlock:
             self.pidlock.release()
             self.pidlock.release()

+ 4 - 4
celery/worker/autoreload.py

@@ -18,12 +18,13 @@ from threading import Event
 
 
 from kombu.utils import eventio
 from kombu.utils import eventio
 
 
+from celery import bootsteps
 from celery.platforms import ignore_errno
 from celery.platforms import ignore_errno
 from celery.utils.imports import module_file
 from celery.utils.imports import module_file
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
 
 
-from .bootsteps import StartStopComponent
+from .components import Pool
 
 
 try:                        # pragma: no cover
 try:                        # pragma: no cover
     import pyinotify
     import pyinotify
@@ -35,9 +36,8 @@ except ImportError:         # pragma: no cover
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-class WorkerComponent(StartStopComponent):
-    name = 'worker.autoreloader'
-    requires = ('pool', )
+class WorkerComponent(bootsteps.StartStopStep):
+    requires = (Pool, )
 
 
     def __init__(self, w, autoreload=None, **kwargs):
     def __init__(self, w, autoreload=None, **kwargs):
         self.enabled = w.autoreload = autoreload
         self.enabled = w.autoreload = autoreload

+ 4 - 4
celery/worker/autoscale.py

@@ -18,20 +18,20 @@ import threading
 from functools import partial
 from functools import partial
 from time import sleep, time
 from time import sleep, time
 
 
+from celery import bootsteps
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
 
 
 from . import state
 from . import state
-from .bootsteps import StartStopComponent
+from .components import Pool
 from .hub import DummyLock
 from .hub import DummyLock
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 debug, info, error = logger.debug, logger.info, logger.error
 debug, info, error = logger.debug, logger.info, logger.error
 
 
 
 
-class WorkerComponent(StartStopComponent):
-    name = 'worker.autoscaler'
-    requires = ('pool', )
+class WorkerComponent(bootsteps.StartStopStep):
+    requires = (Pool, )
 
 
     def __init__(self, w, **kwargs):
     def __init__(self, w, **kwargs):
         self.enabled = w.autoscale
         self.enabled = w.autoscale

+ 0 - 292
celery/worker/bootsteps.py

@@ -1,292 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.worker.bootsteps
-    ~~~~~~~~~~~~~~~~~~~~~~~
-
-    The boot-step components.
-
-"""
-from __future__ import absolute_import
-
-import socket
-
-from collections import defaultdict
-from importlib import import_module
-from threading import Event
-
-from celery.datastructures import DependencyGraph
-from celery.utils.imports import instantiate, qualname
-from celery.utils.log import get_logger
-
-try:
-    from greenlet import GreenletExit
-    IGNORE_ERRORS = (GreenletExit, )
-except ImportError:  # pragma: no cover
-    IGNORE_ERRORS = ()
-
-#: Default socket timeout at shutdown.
-SHUTDOWN_SOCKET_TIMEOUT = 5.0
-
-#: States
-RUN = 0x1
-CLOSE = 0x2
-TERMINATE = 0x3
-
-logger = get_logger(__name__)
-
-
-class Namespace(object):
-    """A namespace containing components.
-
-    Every component must belong to a namespace.
-
-    When component classes are created they are added to the
-    mapping of unclaimed components.  The components will be
-    claimed when the namespace they belong to is created.
-
-    :keyword name: Set the name of this namespace.
-    :keyword app: Set the Celery app for this namespace.
-
-    """
-    name = None
-    state = None
-    started = 0
-
-    _unclaimed = defaultdict(dict)
-
-    def __init__(self, name=None, app=None, on_start=None,
-            on_close=None, on_stopped=None):
-        self.app = app
-        self.name = name or self.name
-        self.on_start = on_start
-        self.on_close = on_close
-        self.on_stopped = on_stopped
-        self.services = []
-        self.shutdown_complete = Event()
-
-    def start(self, parent):
-        self.state = RUN
-        if self.on_start:
-            self.on_start()
-        for i, component in enumerate(parent.components):
-            if component:
-                logger.debug('Starting %s...', qualname(component))
-                self.started = i + 1
-                component.start(parent)
-                logger.debug('%s OK!', qualname(component))
-
-    def close(self, parent):
-        if self.on_close:
-            self.on_close()
-        for component in parent.components:
-            try:
-                close = component.close
-            except AttributeError:
-                pass
-            else:
-                close(parent)
-
-    def stop(self, parent, terminate=False):
-        what = 'Terminating' if terminate else 'Stopping'
-        socket_timeout = socket.getdefaulttimeout()
-        socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
-
-        if self.state in (CLOSE, TERMINATE):
-            return
-
-        self.close(parent)
-
-        if self.state != RUN or self.started != len(parent.components):
-            # Not fully started, can safely exit.
-            self.state = TERMINATE
-            self.shutdown_complete.set()
-            return
-        self.state = CLOSE
-
-        for component in reversed(parent.components):
-            if component:
-                logger.debug('%s %s...', what, qualname(component))
-                (component.terminate if terminate else component.stop)(parent)
-
-        if self.on_stopped:
-            self.on_stopped()
-        self.state = TERMINATE
-        socket.setdefaulttimeout(socket_timeout)
-        self.shutdown_complete.set()
-
-    def join(self, timeout=None):
-        try:
-            # Will only get here if running green,
-            # makes sure all greenthreads have exited.
-            self.shutdown_complete.wait(timeout=timeout)
-        except IGNORE_ERRORS:
-            pass
-
-    def modules(self):
-        """Subclasses can override this to return a
-        list of modules to import before components are claimed."""
-        return []
-
-    def load_modules(self):
-        """Will load the component modules this namespace depends on."""
-        for m in self.modules():
-            self.import_module(m)
-
-    def apply(self, parent, **kwargs):
-        """Apply the components in this namespace to an object.
-
-        This will apply the ``__init__`` and ``include`` methods
-        of each components with the object as argument.
-
-        For ``StartStopComponents`` the services created
-        will also be added the the objects ``components`` attribute.
-
-        """
-        self._debug('Loading modules.')
-        self.load_modules()
-        self._debug('Claiming components.')
-        self.components = self._claim()
-        self._debug('Building boot step graph.')
-        self.boot_steps = [self.bind_component(name, parent, **kwargs)
-                                for name in self._finalize_boot_steps()]
-        self._debug('New boot order: {%s}',
-                ', '.join(c.name for c in self.boot_steps))
-
-        for component in self.boot_steps:
-            component.include(parent)
-        return self
-
-    def bind_component(self, name, parent, **kwargs):
-        """Bind component to parent object and this namespace."""
-        comp = self[name](parent, **kwargs)
-        comp.namespace = self
-        return comp
-
-    def import_module(self, module):
-        return import_module(module)
-
-    def __getitem__(self, name):
-        return self.components[name]
-
-    def _find_last(self):
-        for C in self.components.itervalues():
-            if C.last:
-                return C
-
-    def _finalize_boot_steps(self):
-        G = self.graph = DependencyGraph((C.name, C.requires)
-                            for C in self.components.itervalues())
-        last = self._find_last()
-        if last:
-            for obj in G:
-                if obj != last.name:
-                    G.add_edge(last.name, obj)
-        return G.topsort()
-
-    def _claim(self):
-        return self._unclaimed[self.name]
-
-    def _debug(self, msg, *args):
-        return logger.debug('[%s] ' + msg,
-                            *(self.name.capitalize(), ) + args)
-
-
-class ComponentType(type):
-    """Metaclass for components."""
-
-    def __new__(cls, name, bases, attrs):
-        abstract = attrs.pop('abstract', False)
-        if not abstract:
-            try:
-                cname = attrs['name']
-            except KeyError:
-                raise NotImplementedError('Components must be named')
-            namespace = attrs.get('namespace', None)
-            if not namespace:
-                attrs['namespace'], _, attrs['name'] = cname.partition('.')
-        cls = super(ComponentType, cls).__new__(cls, name, bases, attrs)
-        if not abstract:
-            Namespace._unclaimed[cls.namespace][cls.name] = cls
-        return cls
-
-
-class Component(object):
-    """A component.
-
-    The :meth:`__init__` method is called when the component
-    is bound to a parent object, and can as such be used
-    to initialize attributes in the parent object at
-    parent instantiation-time.
-
-    """
-    __metaclass__ = ComponentType
-
-    #: The name of the component, or the namespace
-    #: and the name of the component separated by dot.
-    name = None
-
-    #: List of component names this component depends on.
-    #: Note that the dependencies must be in the same namespace.
-    requires = ()
-
-    #: can be used to specify the namespace,
-    #: if the name does not include it.
-    namespace = None
-
-    #: if set the component will not be registered,
-    #: but can be used as a component base class.
-    abstract = True
-
-    #: Optional obj created by the :meth:`create` method.
-    #: This is used by StartStopComponents to keep the
-    #: original service object.
-    obj = None
-
-    #: This flag is reserved for the workers Consumer,
-    #: since it is required to always be started last.
-    #: There can only be one object marked with lsat
-    #: in every namespace.
-    last = False
-
-    #: This provides the default for :meth:`include_if`.
-    enabled = True
-
-    def __init__(self, parent, **kwargs):
-        pass
-
-    def create(self, parent):
-        """Create the component."""
-        pass
-
-    def include_if(self, parent):
-        """An optional predicate that decided whether this
-        component should be created."""
-        return self.enabled
-
-    def instantiate(self, qualname, *args, **kwargs):
-        return instantiate(qualname, *args, **kwargs)
-
-    def include(self, parent):
-        if self.include_if(parent):
-            self.obj = self.create(parent)
-            return True
-
-
-class StartStopComponent(Component):
-    abstract = True
-
-    def start(self, parent):
-        return self.obj.start()
-
-    def stop(self, parent):
-        return self.obj.stop()
-
-    def close(self, parent):
-        pass
-
-    def terminate(self, parent):
-        self.stop(parent)
-
-    def include(self, parent):
-        if super(StartStopComponent, self).include(parent):
-            parent.components.append(self)

+ 72 - 58
celery/worker/components.py

@@ -15,16 +15,55 @@ from functools import partial
 
 
 from billiard.exceptions import WorkerLostError
 from billiard.exceptions import WorkerLostError
 
 
+from celery import bootsteps
 from celery.utils.log import worker_logger as logger
 from celery.utils.log import worker_logger as logger
 from celery.utils.timer2 import Schedule
 from celery.utils.timer2 import Schedule
 
 
-from . import bootsteps
+from . import hub
 from .buckets import TaskBucket, FastQueue
 from .buckets import TaskBucket, FastQueue
-from .hub import Hub, BoundedSemaphore
 
 
 
 
-class Pool(bootsteps.StartStopComponent):
-    """The pool component.
+class Hub(bootsteps.StartStopStep):
+
+    def __init__(self, w, **kwargs):
+        w.hub = None
+
+    def include_if(self, w):
+        return w.use_eventloop
+
+    def create(self, w):
+        w.timer = Schedule(max_interval=10)
+        w.hub = hub.Hub(w.timer)
+        return w.hub
+
+
+class Queues(bootsteps.Step):
+    """This step initializes the internal queues
+    used by the worker."""
+    requires = (Hub, )
+
+    def create(self, w):
+        w.start_mediator = True
+        if not w.pool_cls.rlimit_safe:
+            w.disable_rate_limits = True
+        if w.disable_rate_limits:
+            w.ready_queue = FastQueue()
+            if w.use_eventloop:
+                w.start_mediator = False
+                if w.pool_putlocks and w.pool_cls.uses_semaphore:
+                    w.ready_queue.put = w.process_task_sem
+                else:
+                    w.ready_queue.put = w.process_task
+            elif not w.pool_cls.requires_mediator:
+                # just send task directly to pool, skip the mediator.
+                w.ready_queue.put = w.process_task
+                w.start_mediator = False
+        else:
+            w.ready_queue = TaskBucket(task_registry=w.app.tasks)
+
+
+class Pool(bootsteps.StartStopStep):
+    """The pool step.
 
 
     Describes how to initialize the worker pool, and starts and stops
     Describes how to initialize the worker pool, and starts and stops
     the pool during worker startup/shutdown.
     the pool during worker startup/shutdown.
@@ -37,8 +76,7 @@ class Pool(bootsteps.StartStopComponent):
         * min_concurrency
         * min_concurrency
 
 
     """
     """
-    name = 'worker.pool'
-    requires = ('queues', )
+    requires = (Queues, )
 
 
     def __init__(self, w, autoscale=None, autoreload=None,
     def __init__(self, w, autoscale=None, autoreload=None,
             no_execv=False, **kwargs):
             no_execv=False, **kwargs):
@@ -115,7 +153,7 @@ class Pool(bootsteps.StartStopComponent):
         procs = w.min_concurrency
         procs = w.min_concurrency
         forking_enable = not threaded or (w.no_execv or not w.force_execv)
         forking_enable = not threaded or (w.no_execv or not w.force_execv)
         if not threaded:
         if not threaded:
-            semaphore = w.semaphore = BoundedSemaphore(procs)
+            semaphore = w.semaphore = hub.BoundedSemaphore(procs)
             w._quick_acquire = w.semaphore.acquire
             w._quick_acquire = w.semaphore.acquire
             w._quick_release = w.semaphore.release
             w._quick_release = w.semaphore.release
             max_restarts = 100
             max_restarts = 100
@@ -137,14 +175,13 @@ class Pool(bootsteps.StartStopComponent):
         return pool
         return pool
 
 
 
 
-class Beat(bootsteps.StartStopComponent):
-    """Component used to embed a celerybeat process.
+class Beat(bootsteps.StartStopStep):
+    """Step used to embed a celerybeat process.
 
 
     This will only be enabled if the ``beat``
     This will only be enabled if the ``beat``
     argument is set.
     argument is set.
 
 
     """
     """
-    name = 'worker.beat'
 
 
     def __init__(self, w, beat=False, **kwargs):
     def __init__(self, w, beat=False, **kwargs):
         self.enabled = w.beat = beat
         self.enabled = w.beat = beat
@@ -158,51 +195,9 @@ class Beat(bootsteps.StartStopComponent):
         return b
         return b
 
 
 
 
-class Queues(bootsteps.Component):
-    """This component initializes the internal queues
-    used by the worker."""
-    name = 'worker.queues'
-    requires = ('ev', )
-
-    def create(self, w):
-        w.start_mediator = True
-        if not w.pool_cls.rlimit_safe:
-            w.disable_rate_limits = True
-        if w.disable_rate_limits:
-            w.ready_queue = FastQueue()
-            if w.use_eventloop:
-                w.start_mediator = False
-                if w.pool_putlocks and w.pool_cls.uses_semaphore:
-                    w.ready_queue.put = w.process_task_sem
-                else:
-                    w.ready_queue.put = w.process_task
-            elif not w.pool_cls.requires_mediator:
-                # just send task directly to pool, skip the mediator.
-                w.ready_queue.put = w.process_task
-                w.start_mediator = False
-        else:
-            w.ready_queue = TaskBucket(task_registry=w.app.tasks)
-
-
-class EvLoop(bootsteps.StartStopComponent):
-    name = 'worker.ev'
-
-    def __init__(self, w, **kwargs):
-        w.hub = None
-
-    def include_if(self, w):
-        return w.use_eventloop
-
-    def create(self, w):
-        w.timer = Schedule(max_interval=10)
-        hub = w.hub = Hub(w.timer)
-        return hub
-
-
-class Timers(bootsteps.Component):
-    """This component initializes the internal timers used by the worker."""
-    name = 'worker.timers'
-    requires = ('pool', )
+class Timers(bootsteps.Step):
+    """This step initializes the internal timers used by the worker."""
+    requires = (Pool, )
 
 
     def include_if(self, w):
     def include_if(self, w):
         return not w.use_eventloop
         return not w.use_eventloop
@@ -224,9 +219,8 @@ class Timers(bootsteps.Component):
         logger.debug('Timer wake-up! Next eta %s secs.', delay)
         logger.debug('Timer wake-up! Next eta %s secs.', delay)
 
 
 
 
-class StateDB(bootsteps.Component):
-    """This component sets up the workers state db if enabled."""
-    name = 'worker.state-db'
+class StateDB(bootsteps.Step):
+    """This step sets up the workers state db if enabled."""
 
 
     def __init__(self, w, **kwargs):
     def __init__(self, w, **kwargs):
         self.enabled = w.state_db
         self.enabled = w.state_db
@@ -235,3 +229,23 @@ class StateDB(bootsteps.Component):
     def create(self, w):
     def create(self, w):
         w._persistence = w.state.Persistent(w.state_db)
         w._persistence = w.state.Persistent(w.state_db)
         atexit.register(w._persistence.save)
         atexit.register(w._persistence.save)
+
+
+class Consumer(bootsteps.StartStopStep):
+    last = True
+
+    def create(self, w):
+        prefetch_count = w.concurrency * w.prefetch_multiplier
+        c = w.consumer = self.instantiate(w.consumer_cls,
+                w.ready_queue,
+                hostname=w.hostname,
+                send_events=w.send_events,
+                init_callback=w.ready_callback,
+                initial_prefetch_count=prefetch_count,
+                pool=w.pool,
+                timer=w.timer,
+                app=w.app,
+                controller=w,
+                hub=w.hub,
+                worker_options=w.options)
+        return c

+ 262 - 549
celery/worker/consumer.py

@@ -3,7 +3,7 @@
 celery.worker.consumer
 celery.worker.consumer
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 
-This module contains the component responsible for consuming messages
+This module contains the components responsible for consuming messages
 from the broker, processing the messages and keeping the broker connections
 from the broker, processing the messages and keeping the broker connections
 up and running.
 up and running.
 
 
@@ -12,33 +12,45 @@ from __future__ import absolute_import
 
 
 import logging
 import logging
 import socket
 import socket
-import threading
 
 
-from time import sleep
-from Queue import Empty
-
-from kombu.common import QoS
+from kombu.common import QoS, ignore_errors
 from kombu.syn import _detect_environment
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
-from kombu.utils.eventio import READ, WRITE, ERR
 
 
+from celery import bootsteps
 from celery.app import app_or_default
 from celery.app import app_or_default
-from celery.datastructures import AttributeDict
-from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.task.trace import build_tracer
 from celery.task.trace import build_tracer
-from celery.utils import text
-from celery.utils import timer2
+from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
+from celery.utils.text import truncate
 from celery.utils.timeutils import humanize_seconds, timezone
 from celery.utils.timeutils import humanize_seconds, timezone
 
 
-from . import state
-from .bootsteps import StartStopComponent, RUN, CLOSE
-from .control import Panel
-from .heartbeat import Heart
+from . import heartbeat, loops, pidbox
+from .state import task_reserved, maybe_shutdown
+
+CLOSE = bootsteps.CLOSE
+logger = get_logger(__name__)
+debug, info, warn, error, crit = (logger.debug, logger.info, logger.warn,
+                                  logger.error, logger.critical)
+
+CONNECTION_RETRY = """\
+consumer: Connection to broker lost. \
+Trying to re-establish the connection...\
+"""
+
+CONNECTION_RETRY_STEP = """\
+Trying again {when}...\
+"""
 
 
-#: Heartbeat check is called every heartbeat_seconds' / rate'.
-AMQHEARTBEAT_RATE = 2.0
+CONNECTION_ERROR = """\
+consumer: Cannot connect to %s: %s.
+%s
+"""
+
+CONNECTION_FAILOVER = """\
+Will retry using next failover.\
+"""
 
 
 UNKNOWN_FORMAT = """\
 UNKNOWN_FORMAT = """\
 Received and deleted unknown message. Wrong destination?!?
 Received and deleted unknown message. Wrong destination?!?
@@ -53,7 +65,7 @@ The message has been ignored and discarded.
 
 
 Did you remember to import the module containing this task?
 Did you remember to import the module containing this task?
 Or maybe you are using relative imports?
 Or maybe you are using relative imports?
-More: http://docs.celeryq.org/en/latest/userguide/tasks.html#names
+Please see http://bit.ly/gLye1c for more information.
 
 
 The full contents of the message body was:
 The full contents of the message body was:
 %s
 %s
@@ -64,8 +76,8 @@ INVALID_TASK_ERROR = """\
 Received invalid task message: %s
 Received invalid task message: %s
 The message has been ignored and discarded.
 The message has been ignored and discarded.
 
 
-Please ensure your message conforms to the task message format:
-http://docs.celeryq.org/en/latest/internals/protocol.html
+Please ensure your message conforms to the task
+message protocol as described here: http://bit.ly/hYj41y
 
 
 The full contents of the message body was:
 The full contents of the message body was:
 %s
 %s
@@ -76,112 +88,20 @@ body: {0} {{content_type:{1} content_encoding:{2} delivery_info:{3}}}\
 """
 """
 
 
 
 
-RETRY_CONNECTION = """\
-consumer: Connection to broker lost. \
-Trying to re-establish the connection...\
-"""
-
-CONNECTION_ERROR = """\
-consumer: Cannot connect to %s: %s.
-%s
-"""
-
-CONNECTION_RETRY = """\
-Trying again {when}...\
-"""
-
-CONNECTION_FAILOVER = """\
-Will retry using next failover.\
-"""
-
-task_reserved = state.task_reserved
-
-logger = get_logger(__name__)
-info, warn, error, crit = (logger.info, logger.warn,
-                           logger.error, logger.critical)
-
-
-def debug(msg, *args, **kwargs):
-    logger.debug('consumer: {0}'.format(msg), *args, **kwargs)
-
-
 def dump_body(m, body):
 def dump_body(m, body):
-    return '{0} ({1}b)'.format(text.truncate(safe_repr(body), 1024),
+    return '{0} ({1}b)'.format(truncate(safe_repr(body), 1024),
                                len(m.body))
                                len(m.body))
 
 
 
 
-class Component(StartStopComponent):
-    name = 'worker.consumer'
-    last = True
-
-    def Consumer(self, w):
-        return (w.consumer_cls or
-                Consumer if w.hub else BlockingConsumer)
-
-    def create(self, w):
-        prefetch_count = w.concurrency * w.prefetch_multiplier
-        c = w.consumer = self.instantiate(self.Consumer(w),
-                w.ready_queue,
-                hostname=w.hostname,
-                send_events=w.send_events,
-                init_callback=w.ready_callback,
-                initial_prefetch_count=prefetch_count,
-                pool=w.pool,
-                timer=w.timer,
-                app=w.app,
-                controller=w,
-                hub=w.hub)
-        return c
-
-
 class Consumer(object):
 class Consumer(object):
-    """Listen for messages received from the broker and
-    move them to the ready queue for task processing.
-
-    :param ready_queue: See :attr:`ready_queue`.
-    :param timer: See :attr:`timer`.
-
-    """
 
 
-    #: The queue that holds tasks ready for immediate processing.
+    #: Intra-queue for tasks ready to be handled
     ready_queue = None
     ready_queue = None
 
 
-    #: Enable/disable events.
-    send_events = False
-
-    #: Optional callback to be called when the connection is established.
-    #: Will only be called once, even if the connection is lost and
-    #: re-established.
+    #: Optional callback called the first time the worker
+    #: is ready to receive tasks.
     init_callback = None
     init_callback = None
 
 
-    #: The current hostname.  Defaults to the system hostname.
-    hostname = None
-
-    #: Initial QoS prefetch count for the task channel.
-    initial_prefetch_count = 0
-
-    #: A :class:`celery.events.EventDispatcher` for sending events.
-    event_dispatcher = None
-
-    #: The thread that sends event heartbeats at regular intervals.
-    #: The heartbeats are used by monitors to detect that a worker
-    #: went offline/disappeared.
-    heart = None
-
-    #: The broker connection.
-    connection = None
-
-    #: The consumer used to consume task messages.
-    task_consumer = None
-
-    #: The consumer used to consume broadcast commands.
-    broadcast_consumer = None
-
-    #: The process mailbox (kombu pidbox node).
-    pidbox_node = None
-    _pidbox_node_shutdown = None   # used for greenlets
-    _pidbox_node_stopped = None    # used for greenlets
-
     #: The current worker pool instance.
     #: The current worker pool instance.
     pool = None
     pool = None
 
 
@@ -189,187 +109,188 @@ class Consumer(object):
     #: as sending heartbeats.
     #: as sending heartbeats.
     timer = None
     timer = None
 
 
-    # Consumer state, can be RUN or CLOSE.
-    _state = None
+    class Namespace(bootsteps.Namespace):
+        name = 'Consumer'
+        default_steps = [
+            'celery.worker.consumer:Connection',
+            'celery.worker.consumer:Events',
+            'celery.worker.consumer:Heart',
+            'celery.worker.consumer:Control',
+            'celery.worker.consumer:Tasks',
+            'celery.worker.consumer:Evloop',
+        ]
+
+        def shutdown(self, parent):
+            self.restart(parent, 'Shutdown', 'shutdown')
 
 
     def __init__(self, ready_queue,
     def __init__(self, ready_queue,
-            init_callback=noop, send_events=False, hostname=None,
-            initial_prefetch_count=2, pool=None, app=None,
+            init_callback=noop, hostname=None,
+            pool=None, app=None,
             timer=None, controller=None, hub=None, amqheartbeat=None,
             timer=None, controller=None, hub=None, amqheartbeat=None,
-            **kwargs):
+            worker_options=None, **kwargs):
         self.app = app_or_default(app)
         self.app = app_or_default(app)
-        self.connection = None
-        self.task_consumer = None
         self.controller = controller
         self.controller = controller
-        self.broadcast_consumer = None
         self.ready_queue = ready_queue
         self.ready_queue = ready_queue
-        self.send_events = send_events
         self.init_callback = init_callback
         self.init_callback = init_callback
         self.hostname = hostname or socket.gethostname()
         self.hostname = hostname or socket.gethostname()
-        self.initial_prefetch_count = initial_prefetch_count
-        self.event_dispatcher = None
-        self.heart = None
         self.pool = pool
         self.pool = pool
-        self.timer = timer or timer2.default_timer
-        pidbox_state = AttributeDict(app=self.app,
-                                     hostname=self.hostname,
-                                     listener=self,     # pre 2.2
-                                     consumer=self)
-        self.pidbox_node = self.app.control.mailbox.Node(self.hostname,
-                                                         state=pidbox_state,
-                                                         handlers=Panel.data)
+        self.timer = timer or default_timer
+        self.strategies = {}
         conninfo = self.app.connection()
         conninfo = self.app.connection()
         self.connection_errors = conninfo.connection_errors
         self.connection_errors = conninfo.connection_errors
         self.channel_errors = conninfo.channel_errors
         self.channel_errors = conninfo.channel_errors
 
 
         self._does_info = logger.isEnabledFor(logging.INFO)
         self._does_info = logger.isEnabledFor(logging.INFO)
-        self.strategies = {}
-        if hub:
-            hub.on_init.append(self.on_poll_init)
-        self.hub = hub
         self._quick_put = self.ready_queue.put
         self._quick_put = self.ready_queue.put
-        self.amqheartbeat = amqheartbeat
-        if self.amqheartbeat is None:
-            self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
-        if not hub:
+
+        if hub:
+            self.amqheartbeat = amqheartbeat
+            if self.amqheartbeat is None:
+                self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
+            self.hub = hub
+            self.hub.on_init.append(self.on_poll_init)
+        else:
+            self.hub = None
             self.amqheartbeat = 0
             self.amqheartbeat = 0
 
 
+        if not hasattr(self, 'loop'):
+            self.loop = loops.asynloop if hub else loops.synloop
+
         if _detect_environment() == 'gevent':
         if _detect_environment() == 'gevent':
             # there's a gevent bug that causes timeouts to not be reset,
             # there's a gevent bug that causes timeouts to not be reset,
             # so if the connection timeout is exceeded once, it can NEVER
             # so if the connection timeout is exceeded once, it can NEVER
             # connect again.
             # connect again.
             self.app.conf.BROKER_CONNECTION_TIMEOUT = None
             self.app.conf.BROKER_CONNECTION_TIMEOUT = None
 
 
-    def update_strategies(self):
-        S = self.strategies
-        app = self.app
-        loader = app.loader
-        hostname = self.hostname
-        for name, task in self.app.tasks.iteritems():
-            S[name] = task.start_strategy(app, self)
-            task.__trace__ = build_tracer(name, task, loader, hostname)
+        self.steps = []
+        self.namespace = self.Namespace(
+            app=self.app, on_start=self.on_start, on_close=self.on_close,
+        )
+        self.namespace.apply(self, **worker_options or {})
 
 
     def start(self):
     def start(self):
-        """Start the consumer.
+        ns, loop = self.namespace, self.loop
+        while ns.state != CLOSE:
+            maybe_shutdown()
+            try:
+                ns.start(self)
+            except self.connection_errors + self.channel_errors:
+                maybe_shutdown()
+                if ns.state != CLOSE and self.connection:
+                    error(CONNECTION_RETRY, exc_info=True)
+                    ns.restart(self)
 
 
-        Automatically survives intermittent connection failure,
-        and will retry establishing the connection and restart
-        consuming messages.
+    def shutdown(self):
+        self.namespace.shutdown(self)
 
 
-        """
+    def stop(self):
+        self.namespace.stop(self)
 
 
-        self.init_callback(self)
+    def on_start(self):
+        self.update_strategies()
 
 
-        while self._state != CLOSE:
-            self.maybe_shutdown()
-            try:
-                self.reset_connection()
-                self.consume_messages()
-            except self.connection_errors + self.channel_errors:
-                error(RETRY_CONNECTION, exc_info=True)
+    def on_ready(self):
+        callback, self.init_callback = self.init_callback, None
+        if callback:
+            callback(self)
+
+    def loop_args(self):
+        return (self, self.connection, self.task_consumer,
+                self.strategies, self.namespace, self.hub, self.qos,
+                self.amqheartbeat, self.handle_unknown_message,
+                self.handle_unknown_task, self.handle_invalid_task)
 
 
     def on_poll_init(self, hub):
     def on_poll_init(self, hub):
         hub.update_readers(self.connection.eventmap)
         hub.update_readers(self.connection.eventmap)
         self.connection.transport.on_poll_init(hub.poller)
         self.connection.transport.on_poll_init(hub.poller)
 
 
-    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
-            hbrate=AMQHEARTBEAT_RATE):
-        """Consume messages forever (or until an exception is raised)."""
-
-        with self.hub as hub:
-            qos = self.qos
-            update_qos = qos.update
-            update_readers = hub.update_readers
-            readers, writers = hub.readers, hub.writers
-            poll = hub.poller.poll
-            fire_timers = hub.fire_timers
-            scheduled = hub.timer._queue
-            connection = self.connection
-            hb = self.amqheartbeat
-            hbtick = connection.heartbeat_check
-            on_poll_start = connection.transport.on_poll_start
-            on_poll_empty = connection.transport.on_poll_empty
-            strategies = self.strategies
-            drain_nowait = connection.drain_nowait
-            on_task_callbacks = hub.on_task
-            keep_draining = connection.transport.nb_keep_draining
-
-            if hb and connection.supports_heartbeats:
-                hub.timer.apply_interval(
-                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))
-
-            def on_task_received(body, message):
-                if on_task_callbacks:
-                    [callback() for callback in on_task_callbacks]
-                try:
-                    name = body['task']
-                except (KeyError, TypeError):
-                    return self.handle_unknown_message(body, message)
-                try:
-                    strategies[name](message, body, message.ack_log_error)
-                except KeyError as exc:
-                    self.handle_unknown_task(body, message, exc)
-                except InvalidTaskError as exc:
-                    self.handle_invalid_task(body, message, exc)
-                #fire_timers()
-
-            self.task_consumer.callbacks = [on_task_received]
-            self.task_consumer.consume()
-
-            debug('Ready to accept tasks!')
-
-            while self._state != CLOSE and self.connection:
-                # shutdown if signal handlers told us to.
-                if state.should_stop:
-                    raise SystemExit()
-                elif state.should_terminate:
-                    raise SystemTerminate()
-
-                # fire any ready timers, this also returns
-                # the number of seconds until we need to fire timers again.
-                poll_timeout = fire_timers() if scheduled else 1
-
-                # We only update QoS when there is no more messages to read.
-                # This groups together qos calls, and makes sure that remote
-                # control commands will be prioritized over task messages.
-                if qos.prev != qos.value:
-                    update_qos()
-
-                update_readers(on_poll_start())
-                if readers or writers:
-                    connection.more_to_read = True
-                    while connection.more_to_read:
-                        try:
-                            events = poll(poll_timeout)
-                        except ValueError:  # Issue 882
-                            return
-                        if not events:
-                            on_poll_empty()
-                        for fileno, event in events or ():
-                            try:
-                                if event & READ:
-                                    readers[fileno](fileno, event)
-                                if event & WRITE:
-                                    writers[fileno](fileno, event)
-                                if event & ERR:
-                                    for handlermap in readers, writers:
-                                        try:
-                                            handlermap[fileno](fileno, event)
-                                        except KeyError:
-                                            pass
-                            except (KeyError, Empty):
-                                continue
-                            except socket.error:
-                                if self._state != CLOSE:  # pragma: no cover
-                                    raise
-                        if keep_draining:
-                            drain_nowait()
-                            poll_timeout = 0
-                        else:
-                            connection.more_to_read = False
-                else:
-                    # no sockets yet, startup is probably not done.
-                    sleep(min(poll_timeout, 0.1))
+    def on_decode_error(self, message, exc):
+        """Callback called if an error occurs while decoding
+        a message received.
+
+        Simply logs the error and acknowledges the message so it
+        doesn't enter a loop.
+
+        :param message: The message with errors.
+        :param exc: The original exception instance.
+
+        """
+        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
+             exc, message.content_type, message.content_encoding,
+             dump_body(message, message.body))
+        message.ack()
+
+    def on_close(self):
+        # Clear internal queues to get rid of old messages.
+        # They can't be acked anyway, as a delivery tag is specific
+        # to the current channel.
+        self.ready_queue.clear()
+        self.timer.clear()
+
+    def connect(self):
+        """Establish the broker connection.
+
+        Will retry establishing the connection if the
+        :setting:`BROKER_CONNECTION_RETRY` setting is enabled
+
+        """
+        conn = self.app.connection(heartbeat=self.amqheartbeat)
+
+        # Callback called for each retry while the connection
+        # can't be established.
+        def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
+            if getattr(conn, 'alt', None) and interval == 0:
+                next_step = CONNECTION_FAILOVER
+            error(CONNECTION_ERROR, conn.as_uri(), exc,
+                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))
+
+        # remember that the connection is lazy, it won't establish
+        # until it's needed.
+        if not self.app.conf.BROKER_CONNECTION_RETRY:
+            # retry disabled, just call connect directly.
+            conn.connect()
+            return conn
+
+        return conn.ensure_connection(_error_handler,
+                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
+                    callback=maybe_shutdown)
+
+    def add_task_queue(self, queue, exchange=None, exchange_type=None,
+            routing_key=None, **options):
+        cset = self.task_consumer
+        try:
+            q = self.app.amqp.queues[queue]
+        except KeyError:
+            exchange = queue if exchange is None else exchange
+            exchange_type = 'direct' if exchange_type is None \
+                                     else exchange_type
+            q = self.app.amqp.queues.select_add(queue,
+                    exchange=exchange,
+                    exchange_type=exchange_type,
+                    routing_key=routing_key, **options)
+        if not cset.consuming_from(queue):
+            cset.add_queue(q)
+            cset.consume()
+            info('Started consuming from %r', queue)
+
+    def cancel_task_queue(self, queue):
+        self.app.amqp.queues.select_remove(queue)
+        self.task_consumer.cancel_by_queue(queue)
+
+    @property
+    def info(self):
+        """Returns information about this consumer instance
+        as a dict.
+
+        This is also the consumer related info returned by
+        ``celeryctl stats``.
+
+        """
+        conninfo = {}
+        if self.connection:
+            conninfo = self.connection.info()
+            conninfo.pop('password', None)  # don't send password.
+        return {'broker': conninfo, 'prefetch_count': self.qos.value}
 
 
     def on_task(self, task, task_reserved=task_reserved):
     def on_task(self, task, task_reserved=task_reserved):
         """Handle received task.
         """Handle received task.
@@ -395,7 +316,7 @@ class Consumer(object):
         if task.eta:
         if task.eta:
             eta = timezone.to_system(task.eta) if task.utc else task.eta
             eta = timezone.to_system(task.eta) if task.utc else task.eta
             try:
             try:
-                eta = timer2.to_timestamp(eta)
+                eta = to_timestamp(eta)
             except OverflowError as exc:
             except OverflowError as exc:
                 error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                 error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                       task.eta, exc, task.info(safe=True), exc_info=True)
                       task.eta, exc, task.info(safe=True), exc_info=True)
@@ -409,16 +330,6 @@ class Consumer(object):
             task_reserved(task)
             task_reserved(task)
             self._quick_put(task)
             self._quick_put(task)
 
 
-    def on_control(self, body, message):
-        """Process remote control command message."""
-        try:
-            self.pidbox_node.handle_message(body, message)
-        except KeyError as exc:
-            error('No such control command: %s', exc)
-        except Exception as exc:
-            error('Control command error: %r', exc, exc_info=True)
-            self.reset_pidbox_node()
-
     def apply_eta_task(self, task):
     def apply_eta_task(self, task):
         """Method called by the timer to apply a task with an
         """Method called by the timer to apply a task with an
         ETA/countdown."""
         ETA/countdown."""
@@ -444,307 +355,109 @@ class Consumer(object):
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         message.reject_log_error(logger, self.connection_errors)
         message.reject_log_error(logger, self.connection_errors)
 
 
-    def receive_message(self, body, message):
-        """Handles incoming messages.
+    def update_strategies(self):
+        loader = self.app.loader
+        for name, task in self.app.tasks.iteritems():
+            self.strategies[name] = task.start_strategy(self.app, self)
+            task.__trace__ = build_tracer(name, task, loader, self.hostname)
 
 
-        :param body: The message body.
-        :param message: The kombu message object.
 
 
-        """
-        try:
-            name = body['task']
-        except (KeyError, TypeError):
-            return self.handle_unknown_message(body, message)
+class Connection(bootsteps.StartStopStep):
 
 
-        try:
-            self.strategies[name](message, body, message.ack_log_error)
-        except KeyError as exc:
-            self.handle_unknown_task(body, message, exc)
-        except InvalidTaskError as exc:
-            self.handle_invalid_task(body, message, exc)
-
-    def maybe_conn_error(self, fun):
-        """Applies function but ignores any connection or channel
-        errors raised."""
-        try:
-            fun()
-        except (AttributeError, ) + \
-                self.connection_errors + \
-                self.channel_errors:
-            pass
+    def __init__(self, c, **kwargs):
+        c.connection = None
 
 
-    def close_connection(self):
-        """Closes the current broker connection and all open channels."""
+    def start(self, c):
+        c.connection = c.connect()
+        info('Connected to %s', c.connection.as_uri())
 
 
+    def shutdown(self, c):
         # We must set self.connection to None here, so
         # We must set self.connection to None here, so
         # that the green pidbox thread exits.
         # that the green pidbox thread exits.
-        connection, self.connection = self.connection, None
-
-        if self.task_consumer:
-            debug('Closing consumer channel...')
-            self.task_consumer = \
-                    self.maybe_conn_error(self.task_consumer.close)
-
-        self.stop_pidbox_node()
-
+        connection, c.connection = c.connection, None
         if connection:
         if connection:
-            debug('Closing broker connection...')
-            self.maybe_conn_error(connection.close)
-
-    def stop_consumers(self, close_connection=True, join=True):
-        """Stop consuming tasks and broadcast commands, also stops
-        the heartbeat thread and event dispatcher.
+            ignore_errors(connection, connection.close)
 
 
-        :keyword close_connection: Set to False to skip closing the broker
-                                    connection.
-
-        """
-        if not self._state == RUN:
-            return
-
-        if self.heart:
-            # Stop the heartbeat thread if it's running.
-            debug('Heart: Going into cardiac arrest...')
-            self.heart = self.heart.stop()
-
-        debug('Cancelling task consumer...')
-        if join and self.task_consumer:
-            self.maybe_conn_error(self.task_consumer.cancel)
-
-        if self.event_dispatcher:
-            debug('Shutting down event dispatcher...')
-            self.event_dispatcher = \
-                    self.maybe_conn_error(self.event_dispatcher.close)
-
-        debug('Cancelling broadcast consumer...')
-        if join and self.broadcast_consumer:
-            self.maybe_conn_error(self.broadcast_consumer.cancel)
-
-        if close_connection:
-            self.close_connection()
-
-    def on_decode_error(self, message, exc):
-        """Callback called if an error occurs while decoding
-        a message received.
-
-        Simply logs the error and acknowledges the message so it
-        doesn't enter a loop.
-
-        :param message: The message with errors.
-        :param exc: The original exception instance.
-
-        """
-        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
-             exc, message.content_type, message.content_encoding,
-             dump_body(message, message.body))
-        message.ack()
-
-    def reset_pidbox_node(self):
-        """Sets up the process mailbox."""
-        self.stop_pidbox_node()
-        # close previously opened channel if any.
-        if self.pidbox_node.channel:
-            try:
-                self.pidbox_node.channel.close()
-            except self.connection_errors + self.channel_errors:
-                pass
-
-        if self.pool is not None and self.pool.is_green:
-            return self.pool.spawn_n(self._green_pidbox_node)
-        self.pidbox_node.channel = self.connection.channel()
-        self.broadcast_consumer = self.pidbox_node.listen(
-                                        callback=self.on_control)
-
-    def stop_pidbox_node(self):
-        if self._pidbox_node_stopped:
-            self._pidbox_node_shutdown.set()
-            debug('Waiting for broadcast thread to shutdown...')
-            self._pidbox_node_stopped.wait()
-            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
-        elif self.broadcast_consumer:
-            debug('Closing broadcast channel...')
-            self.broadcast_consumer = \
-                self.maybe_conn_error(self.broadcast_consumer.channel.close)
-
-    def _green_pidbox_node(self):
-        """Sets up the process mailbox when running in a greenlet
-        environment."""
-        # THIS CODE IS TERRIBLE
-        # Luckily work has already started rewriting the Consumer for 4.0.
-        self._pidbox_node_shutdown = threading.Event()
-        self._pidbox_node_stopped = threading.Event()
-        try:
-            with self._open_connection() as conn:
-                info('pidbox: Connected to %s.', conn.as_uri())
-                self.pidbox_node.channel = conn.default_channel
-                self.broadcast_consumer = self.pidbox_node.listen(
-                                            callback=self.on_control)
-                with self.broadcast_consumer:
-                    while not self._pidbox_node_shutdown.isSet():
-                        try:
-                            conn.drain_events(timeout=1.0)
-                        except socket.timeout:
-                            pass
-        finally:
-            self._pidbox_node_stopped.set()
-
-    def reset_connection(self):
-        """Re-establish the broker connection and set up consumers,
-        heartbeat and the event dispatcher."""
-        debug('Re-establishing connection to the broker...')
-        self.stop_consumers(join=False)
-
-        # Clear internal queues to get rid of old messages.
-        # They can't be acked anyway, as a delivery tag is specific
-        # to the current channel.
-        self.ready_queue.clear()
-        self.timer.clear()
 
 
-        # Re-establish the broker connection and setup the task consumer.
-        self.connection = self._open_connection()
-        info('consumer: Connected to %s.', self.connection.as_uri())
-        self.task_consumer = self.app.amqp.TaskConsumer(self.connection,
-                                    on_decode_error=self.on_decode_error)
-        # QoS: Reset prefetch window.
-        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
-        self.qos.update()
+class Events(bootsteps.StartStopStep):
+    requires = (Connection, )
 
 
-        # Setup the process mailbox.
-        self.reset_pidbox_node()
+    def __init__(self, c, send_events=None, **kwargs):
+        self.send_events = send_events
+        c.event_dispatcher = None
 
 
+    def start(self, c):
         # Flush events sent while connection was down.
         # Flush events sent while connection was down.
-        prev_event_dispatcher = self.event_dispatcher
-        self.event_dispatcher = self.app.events.Dispatcher(self.connection,
-                                                hostname=self.hostname,
-                                                enabled=self.send_events)
-        if prev_event_dispatcher:
-            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
-            self.event_dispatcher.flush()
-
-        # Restart heartbeat thread.
-        self.restart_heartbeat()
-
-        # reload all task's execution strategies.
-        self.update_strategies()
-
-        # We're back!
-        self._state = RUN
-
-    def restart_heartbeat(self):
-        """Restart the heartbeat thread.
-
-        This thread sends heartbeat events at intervals so monitors
-        can tell if the worker is off-line/missing.
-
-        """
-        self.heart = Heart(self.timer, self.event_dispatcher)
-        self.heart.start()
-
-    def _open_connection(self):
-        """Establish the broker connection.
-
-        Will retry establishing the connection if the
-        :setting:`BROKER_CONNECTION_RETRY` setting is enabled
-
-        """
-        conn = self.app.connection(heartbeat=self.amqheartbeat)
-
-        # Callback called for each retry while the connection
-        # can't be established.
-        def _error_handler(exc, interval, next_step=CONNECTION_RETRY):
-            if getattr(conn, 'alt', None) and interval == 0:
-                next_step = CONNECTION_FAILOVER
-            error(CONNECTION_ERROR, conn.as_uri(), exc,
-                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))
+        prev = c.event_dispatcher
+        dis = c.event_dispatcher = c.app.events.Dispatcher(
+            c.connection, hostname=c.hostname, enabled=self.send_events,
+        )
+        if prev:
+            dis.copy_buffer(prev)
+            dis.flush()
 
 
-        # remember that the connection is lazy, it won't establish
-        # until it's needed.
-        if not self.app.conf.BROKER_CONNECTION_RETRY:
-            # retry disabled, just call connect directly.
-            conn.connect()
-            return conn
-
-        return conn.ensure_connection(_error_handler,
-                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
-                    callback=self.maybe_shutdown)
+    def stop(self, c):
+        if c.event_dispatcher:
+            ignore_errors(c, c.event_dispatcher.close)
+            c.event_dispatcher = None
+    shutdown = stop
 
 
-    def stop(self):
-        """Stop consuming.
 
 
-        Does not close the broker connection, so be sure to call
-        :meth:`close_connection` when you are finished with it.
+class Heart(bootsteps.StartStopStep):
+    requires = (Events, )
 
 
-        """
-        # Notifies other threads that this instance can't be used
-        # anymore.
-        self.close()
-        debug('Stopping consumers...')
-        self.stop_consumers(close_connection=False, join=True)
+    def __init__(self, c, **kwargs):
+        c.heart = None
 
 
-    def close(self):
-        self._state = CLOSE
+    def start(self, c):
+        c.heart = heartbeat.Heart(c.timer, c.event_dispatcher)
+        c.heart.start()
 
 
-    def maybe_shutdown(self):
-        if state.should_stop:
-            raise SystemExit()
-        elif state.should_terminate:
-            raise SystemTerminate()
+    def stop(self, c):
+        c.heart = c.heart and c.heart.stop()
+    shutdown = stop
 
 
-    def add_task_queue(self, queue, exchange=None, exchange_type=None,
-            routing_key=None, **options):
-        cset = self.task_consumer
-        try:
-            q = self.app.amqp.queues[queue]
-        except KeyError:
-            exchange = queue if exchange is None else exchange
-            exchange_type = 'direct' if exchange_type is None \
-                                     else exchange_type
-            q = self.app.amqp.queues.select_add(queue,
-                    exchange=exchange,
-                    exchange_type=exchange_type,
-                    routing_key=routing_key, **options)
-        if not cset.consuming_from(queue):
-            cset.add_queue(q)
-            cset.consume()
-            logger.info('Started consuming from %r', queue)
 
 
-    def cancel_task_queue(self, queue):
-        self.app.amqp.queues.select_remove(queue)
-        self.task_consumer.cancel_by_queue(queue)
+class Control(bootsteps.StartStopStep):
+    requires = (Events, )
 
 
-    @property
-    def info(self):
-        """Returns information about this consumer instance
-        as a dict.
+    def __init__(self, c, **kwargs):
+        self.is_green = c.pool is not None and c.pool.is_green
+        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
+        self.start = self.box.start
+        self.stop = self.box.stop
+        self.shutdown = self.box.shutdown
 
 
-        This is also the consumer related info returned by
-        ``celeryctl stats``.
 
 
-        """
-        conninfo = {}
-        if self.connection:
-            conninfo = self.connection.info()
-            conninfo.pop('password', None)  # don't send password.
-        return {'broker': conninfo, 'prefetch_count': self.qos.value}
+class Tasks(bootsteps.StartStopStep):
+    requires = (Control, )
 
 
+    def __init__(self, c, initial_prefetch_count=2, **kwargs):
+        c.task_consumer = c.qos = None
+        self.initial_prefetch_count = initial_prefetch_count
 
 
-class BlockingConsumer(Consumer):
+    def start(self, c):
+        c.task_consumer = c.app.amqp.TaskConsumer(
+            c.connection, on_decode_error=c.on_decode_error,
+        )
+        c.qos = QoS(c.task_consumer, self.initial_prefetch_count)
+        c.qos.update()  # set initial prefetch count
+
+    def stop(self, c):
+        if c.task_consumer:
+            debug('Cancelling task consumer...')
+            ignore_errors(c, c.task_consumer.cancel)
+
+    def shutdown(self, c):
+        if c.task_consumer:
+            self.stop(c)
+            debug('Closing consumer channel...')
+            ignore_errors(c, c.task_consumer.close)
+            c.task_consumer = None
 
 
-    def consume_messages(self):
-        # receive_message handles incoming messages.
-        self.task_consumer.register_callback(self.receive_message)
-        self.task_consumer.consume()
 
 
-        debug('Ready to accept tasks!')
+class Evloop(bootsteps.StartStopStep):
+    last = True
 
 
-        while self._state != CLOSE and self.connection:
-            self.maybe_shutdown()
-            if self.qos.prev != self.qos.value:     # pragma: no cover
-                self.qos.update()
-            try:
-                self.connection.drain_events(timeout=10.0)
-            except socket.timeout:
-                pass
-            except socket.error:
-                if self._state != CLOSE:            # pragma: no cover
-                    raise
+    def start(self, c):
+        c.loop(*c.loop_args())

+ 2 - 2
celery/worker/hub.py

@@ -132,11 +132,11 @@ class Hub(object):
         self.on_task = []
         self.on_task = []
 
 
     def start(self):
     def start(self):
-        """Called by StartStopComponent at worker startup."""
+        """Called by Hub bootstep at worker startup."""
         self.poller = eventio.poll()
         self.poller = eventio.poll()
 
 
     def stop(self):
     def stop(self):
-        """Called by StartStopComponent at worker shutdown."""
+        """Called by Hub bootstep at worker shutdown."""
         self.poller.close()
         self.poller.close()
 
 
     def init(self):
     def init(self):

+ 156 - 0
celery/worker/loops.py

@@ -0,0 +1,156 @@
+"""
+celery.worker.loops
+~~~~~~~~~~~~~~~~~~~
+
+The consumer's highly-optimized inner loop.
+
+"""
+from __future__ import absolute_import
+
+import socket
+
+from time import sleep
+from Queue import Empty
+
+from kombu.utils.eventio import READ, WRITE, ERR
+
+from celery.bootsteps import CLOSE
+from celery.exceptions import InvalidTaskError, SystemTerminate
+
+from . import state
+
+#: Heartbeat check is called every ``heartbeat_seconds / rate`` seconds.
+AMQHEARTBEAT_RATE = 2.0
+
+
+def asynloop(obj, connection, consumer, strategies, ns, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, sleep=sleep, min=min, Empty=Empty,
+        hbrate=AMQHEARTBEAT_RATE):
+    """Non-blocking eventloop consuming messages until connection is lost,
+    or shutdown is requested."""
+
+    with hub as hub:
+        update_qos = qos.update
+        update_readers = hub.update_readers
+        readers, writers = hub.readers, hub.writers
+        poll = hub.poller.poll
+        fire_timers = hub.fire_timers
+        scheduled = hub.timer._queue
+        hbtick = connection.heartbeat_check
+        on_poll_start = connection.transport.on_poll_start
+        on_poll_empty = connection.transport.on_poll_empty
+        drain_nowait = connection.drain_nowait
+        on_task_callbacks = hub.on_task
+        keep_draining = connection.transport.nb_keep_draining
+
+        if heartbeat and connection.supports_heartbeats:
+            hub.timer.apply_interval(
+                heartbeat * 1000.0 / hbrate, hbtick, (hbrate, ))
+
+        def on_task_received(body, message):
+            if on_task_callbacks:
+                [callback() for callback in on_task_callbacks]
+            try:
+                name = body['task']
+            except (KeyError, TypeError):
+                return handle_unknown_message(body, message)
+            try:
+                strategies[name](message, body, message.ack_log_error)
+            except KeyError as exc:
+                handle_unknown_task(body, message, exc)
+            except InvalidTaskError as exc:
+                handle_invalid_task(body, message, exc)
+
+        consumer.callbacks = [on_task_received]
+        consumer.consume()
+        obj.on_ready()
+
+        while ns.state != CLOSE and obj.connection:
+            # shutdown if signal handlers told us to.
+            if state.should_stop:
+                raise SystemExit()
+            elif state.should_terminate:
+                raise SystemTerminate()
+
+            # fire any ready timers, this also returns
+            # the number of seconds until we need to fire timers again.
+            poll_timeout = fire_timers() if scheduled else 1
+
+            # We only update QoS when there is no more messages to read.
+            # This groups together qos calls, and makes sure that remote
+            # control commands will be prioritized over task messages.
+            if qos.prev != qos.value:
+                update_qos()
+
+            update_readers(on_poll_start())
+            if readers or writers:
+                connection.more_to_read = True
+                while connection.more_to_read:
+                    try:
+                        events = poll(poll_timeout)
+                    except ValueError:  # Issue 882
+                        return
+                    if not events:
+                        on_poll_empty()
+                    for fileno, event in events or ():
+                        try:
+                            if event & READ:
+                                readers[fileno](fileno, event)
+                            if event & WRITE:
+                                writers[fileno](fileno, event)
+                            if event & ERR:
+                                for handlermap in readers, writers:
+                                    try:
+                                        handlermap[fileno](fileno, event)
+                                    except KeyError:
+                                        pass
+                        except (KeyError, Empty):
+                            continue
+                        except socket.error:
+                            if ns.state != CLOSE:  # pragma: no cover
+                                raise
+                    if keep_draining:
+                        drain_nowait()
+                        poll_timeout = 0
+                    else:
+                        connection.more_to_read = False
+            else:
+                # no sockets yet, startup is probably not done.
+                sleep(min(poll_timeout, 0.1))
+
+
+def synloop(obj, connection, consumer, strategies, ns, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, **kwargs):
+    """Fallback blocking eventloop for transports that don't support AIO."""
+
+    def on_task_received(body, message):
+        try:
+            name = body['task']
+        except (KeyError, TypeError):
+            return handle_unknown_message(body, message)
+
+        try:
+            strategies[name](message, body, message.ack_log_error)
+        except KeyError as exc:
+            handle_unknown_task(body, message, exc)
+        except InvalidTaskError as exc:
+            handle_invalid_task(body, message, exc)
+
+    consumer.register_callback(on_task_received)
+    consumer.consume()
+
+    obj.on_ready()
+
+    while ns.state != CLOSE and obj.connection:
+        state.maybe_shutdown()
+        if qos.prev != qos.value:         # pragma: no cover
+            qos.update()
+        try:
+            connection.drain_events(timeout=2.0)
+        except socket.timeout:
+            pass
+        except socket.error:
+            if ns.state != CLOSE:  # pragma: no cover
+                raise

+ 4 - 4
celery/worker/mediator.py

@@ -20,17 +20,17 @@ import logging
 from Queue import Empty
 from Queue import Empty
 
 
 from celery.app import app_or_default
 from celery.app import app_or_default
+from celery.bootsteps import StartStopStep
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 
 
-from .bootsteps import StartStopComponent
+from . import components
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-class WorkerComponent(StartStopComponent):
-    name = 'worker.mediator'
-    requires = ('pool', 'queues', )
+class WorkerComponent(StartStopStep):
+    requires = (components.Pool, components.Queues, )
 
 
     def __init__(self, w, **kwargs):
     def __init__(self, w, **kwargs):
         w.mediator = None
         w.mediator = None

+ 103 - 0
celery/worker/pidbox.py

@@ -0,0 +1,103 @@
+from __future__ import absolute_import
+
+import socket
+import threading
+
+from kombu.common import ignore_errors
+
+from celery.datastructures import AttributeDict
+from celery.utils.log import get_logger
+
+from . import control
+
+logger = get_logger(__name__)
+debug, error, info = logger.debug, logger.error, logger.info
+
+
+class Pidbox(object):
+    consumer = None
+
+    def __init__(self, c):
+        self.c = c
+        self.hostname = c.hostname
+        self.node = c.app.control.mailbox.Node(c.hostname,
+            handlers=control.Panel.data,
+            state=AttributeDict(app=c.app, hostname=c.hostname, consumer=c),
+        )
+
+    def on_message(self, body, message):
+        try:
+            self.node.handle_message(body, message)
+        except KeyError as exc:
+            error('No such control command: %s', exc)
+        except Exception as exc:
+            error('Control command error: %r', exc, exc_info=True)
+            self.reset()
+
+    def start(self, c):
+        self.node.channel = c.connection.channel()
+        self.consumer = self.node.listen(callback=self.on_message)
+
+    def stop(self, c):
+        self.consumer = self._close_channel(c)
+
+    def reset(self):
+        """Sets up the process mailbox."""
+        self.stop(self.c)
+        self.start(self.c)
+
+    def _close_channel(self, c):
+        if self.node and self.node.channel:
+            ignore_errors(c, self.node.channel.close)
+
+    def shutdown(self, c):
+        if self.consumer:
+            debug('Cancelling broadcast consumer...')
+            ignore_errors(c, self.consumer.cancel)
+        self.stop(self.c)
+
+
+class gPidbox(Pidbox):
+    _node_shutdown = None
+    _node_stopped = None
+    _resets = 0
+
+    def start(self, c):
+        c.pool.spawn_n(self.loop, c)
+
+    def stop(self, c):
+        if self._node_stopped:
+            self._node_shutdown.set()
+            debug('Waiting for broadcast thread to shutdown...')
+            self._node_stopped.wait()
+            self._node_stopped = self._node_shutdown = None
+        super(gPidbox, self).stop(c)
+
+    def reset(self):
+        self._resets += 1
+
+    def _do_reset(self, c, connection):
+        self._close_channel(c)
+        self.node.channel = connection.channel()
+        self.consumer = self.node.listen(callback=self.on_message)
+        self.consumer.consume()
+
+    def loop(self, c):
+        resets = [self._resets]
+        shutdown = self._node_shutdown = threading.Event()
+        stopped = self._node_stopped = threading.Event()
+        try:
+            with c.connect() as connection:
+
+                info('pidbox: Connected to %s.', connection.as_uri())
+                self._do_reset(c, connection)
+                while not shutdown.is_set() and c.connection:
+                    if resets[0] < self._resets:
+                        resets[0] += 1
+                        self._do_reset(c, connection)
+                    try:
+                        connection.drain_events(timeout=1.0)
+                    except socket.timeout:
+                        pass
+        finally:
+            stopped.set()

+ 8 - 0
celery/worker/state.py

@@ -20,6 +20,7 @@ from collections import defaultdict
 from kombu.utils import cached_property
 from kombu.utils import cached_property
 
 
 from celery import __version__
 from celery import __version__
+from celery.exceptions import SystemTerminate
 from celery.datastructures import LimitedSet
 from celery.datastructures import LimitedSet
 
 
 #: Worker software/platform information.
 #: Worker software/platform information.
@@ -53,6 +54,13 @@ should_stop = False
 should_terminate = False
 should_terminate = False
 
 
 
 
+def maybe_shutdown():
+    if should_stop:
+        raise SystemExit()
+    elif should_terminate:
+        raise SystemTerminate()
+
+
 def task_accepted(request):
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     """Updates global state when a task has been accepted."""
     active_requests.add(request)
     active_requests.add(request)

+ 12 - 1
docs/configuration.rst

@@ -1408,9 +1408,20 @@ CELERYD_BOOT_STEPS
 ~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~
 
 
 This setting enables you to add additional components to the worker process.
 This setting enables you to add additional components to the worker process.
-It should be a list of module names with :class:`celery.abstract.Component`
+It should be a list of module names with
+:class:`celery.bootsteps.Step`
 classes, that augments functionality in the worker.
 classes, that augments functionality in the worker.
 
 
+.. setting:: CELERYD_CONSUMER_BOOT_STEPS
+
+CELERYD_CONSUMER_BOOT_STEPS
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This setting enables you to add additional components to the worker's consumer.
+It should be a list of module names with
+:class:`celery.bootsteps.Step` classes that augment
+functionality in the consumer.
+
 .. setting:: CELERYD_POOL
 .. setting:: CELERYD_POOL
 
 
 CELERYD_POOL
 CELERYD_POOL

+ 0 - 1
docs/internals/reference/index.rst

@@ -21,7 +21,6 @@
     celery.worker.strategy
     celery.worker.strategy
     celery.worker.autoreload
     celery.worker.autoreload
     celery.worker.autoscale
     celery.worker.autoscale
-    celery.worker.bootsteps
     celery.concurrency
     celery.concurrency
     celery.concurrency.solo
     celery.concurrency.solo
     celery.concurrency.processes
     celery.concurrency.processes

+ 3 - 3
docs/internals/reference/celery.worker.bootsteps.rst → docs/reference/celery.bootsteps.rst

@@ -1,11 +1,11 @@
 ==========================================
 ==========================================
- celery.worker.bootsteps
+ celery.bootsteps
 ==========================================
 ==========================================
 
 
 .. contents::
 .. contents::
     :local:
     :local:
-.. currentmodule:: celery.worker.bootsteps
+.. currentmodule:: celery.bootsteps
 
 
-.. automodule:: celery.worker.bootsteps
+.. automodule:: celery.bootsteps
     :members:
     :members:
     :undoc-members:
     :undoc-members: