Browse Source

Consumer bootsteps working!!! :happy:

Ask Solem 12 years ago
parent
commit
62443a1aa8

+ 26 - 26
celery/tests/worker/test_bootsteps.py

@@ -7,29 +7,29 @@ from celery.worker import bootsteps
 from celery.tests.utils import AppCase, Case
 from celery.tests.utils import AppCase, Case
 
 
 
 
-class test_Component(Case):
+class test_Step(Case):
 
 
-    class Def(bootsteps.Component):
-        name = 'test_Component.Def'
+    class Def(bootsteps.Step):
+        name = 'test_Step.Def'
 
 
-    def test_components_must_be_named(self):
+    def test_steps_must_be_named(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
 
 
-            class X(bootsteps.Component):
+            class X(bootsteps.Step):
                 pass
                 pass
 
 
-        class Y(bootsteps.Component):
+        class Y(bootsteps.Step):
             abstract = True
             abstract = True
 
 
     def test_namespace_name(self, ns='test_namespace_name'):
     def test_namespace_name(self, ns='test_namespace_name'):
 
 
-        class X(bootsteps.Component):
+        class X(bootsteps.Step):
             namespace = ns
             namespace = ns
             name = 'X'
             name = 'X'
         self.assertEqual(X.namespace, ns)
         self.assertEqual(X.namespace, ns)
         self.assertEqual(X.name, 'X')
         self.assertEqual(X.name, 'X')
 
 
-        class Y(bootsteps.Component):
+        class Y(bootsteps.Step):
             name = '%s.Y' % (ns, )
             name = '%s.Y' % (ns, )
         self.assertEqual(Y.namespace, ns)
         self.assertEqual(Y.namespace, ns)
         self.assertEqual(Y.name, 'Y')
         self.assertEqual(Y.name, 'Y')
@@ -70,13 +70,13 @@ class test_Component(Case):
         self.assertFalse(x.create.call_count)
         self.assertFalse(x.create.call_count)
 
 
 
 
-class test_StartStopComponent(Case):
+class test_StartStopStep(Case):
 
 
-    class Def(bootsteps.StartStopComponent):
-        name = 'test_StartStopComponent.Def'
+    class Def(bootsteps.StartStopStep):
+        name = 'test_StartStopStep.Def'
 
 
     def setUp(self):
     def setUp(self):
-        self.components = []
+        self.steps = []
 
 
     def test_start__stop(self):
     def test_start__stop(self):
         x = self.Def(self)
         x = self.Def(self)
@@ -84,10 +84,10 @@ class test_StartStopComponent(Case):
 
 
         # include creates the underlying object and sets
         # include creates the underlying object and sets
         # its x.obj attribute to it, as well as appending
         # its x.obj attribute to it, as well as appending
-        # it to the parent.components list.
+        # it to the parent.steps list.
         x.include(self)
         x.include(self)
-        self.assertTrue(self.components)
-        self.assertIs(self.components[0], x)
+        self.assertTrue(self.steps)
+        self.assertIs(self.steps[0], x)
 
 
         x.start(self)
         x.start(self)
         x.obj.start.assert_called_with()
         x.obj.start.assert_called_with()
@@ -99,7 +99,7 @@ class test_StartStopComponent(Case):
         x = self.Def(self)
         x = self.Def(self)
         x.enabled = False
         x.enabled = False
         x.include(self)
         x.include(self)
-        self.assertFalse(self.components)
+        self.assertFalse(self.steps)
 
 
     def test_terminate(self):
     def test_terminate(self):
         x = self.Def(self)
         x = self.Def(self)
@@ -128,15 +128,15 @@ class test_Namespace(AppCase):
         def import_module(self, module):
         def import_module(self, module):
             self.imported.append(module)
             self.imported.append(module)
 
 
-    def test_components_added_to_unclaimed(self):
+    def test_steps_added_to_unclaimed(self):
 
 
-        class tnA(bootsteps.Component):
+        class tnA(bootsteps.Step):
             name = 'test_Namespace.A'
             name = 'test_Namespace.A'
 
 
-        class tnB(bootsteps.Component):
+        class tnB(bootsteps.Step):
             name = 'test_Namespace.B'
             name = 'test_Namespace.B'
 
 
-        class xxA(bootsteps.Component):
+        class xxA(bootsteps.Step):
             name = 'xx.A'
             name = 'xx.A'
 
 
         self.assertIn('A', self.NS._unclaimed['test_Namespace'])
         self.assertIn('A', self.NS._unclaimed['test_Namespace'])
@@ -166,18 +166,18 @@ class test_Namespace(AppCase):
             def modules(self):
             def modules(self):
                 return ['A', 'B']
                 return ['A', 'B']
 
 
-        class A(bootsteps.Component):
+        class A(bootsteps.Step):
             name = 'test_apply.A'
             name = 'test_apply.A'
             requires = ['C']
             requires = ['C']
 
 
-        class B(bootsteps.Component):
+        class B(bootsteps.Step):
             name = 'test_apply.B'
             name = 'test_apply.B'
 
 
-        class C(bootsteps.Component):
+        class C(bootsteps.Step):
             name = 'test_apply.C'
             name = 'test_apply.C'
             requires = ['B']
             requires = ['B']
 
 
-        class D(bootsteps.Component):
+        class D(bootsteps.Step):
             name = 'test_apply.D'
             name = 'test_apply.D'
             last = True
             last = True
 
 
@@ -185,7 +185,7 @@ class test_Namespace(AppCase):
         x.import_module = Mock()
         x.import_module = Mock()
         x.apply(self)
         x.apply(self)
 
 
-        self.assertItemsEqual(x.components.values(), [A, B, C, D])
+        self.assertItemsEqual(x.steps.values(), [A, B, C, D])
         self.assertTrue(x.import_module.call_count)
         self.assertTrue(x.import_module.call_count)
 
 
         for boot_step in x.boot_steps:
         for boot_step in x.boot_steps:
@@ -203,7 +203,7 @@ class test_Namespace(AppCase):
         import os
         import os
         self.assertIs(x.import_module('os'), os)
         self.assertIs(x.import_module('os'), os)
 
 
-    def test_find_last_but_no_components(self):
+    def test_find_last_but_no_steps(self):
 
 
         class MyNS(bootsteps.Namespace):
         class MyNS(bootsteps.Namespace):
             name = 'qwejwioqjewoqiej'
             name = 'qwejwioqjewoqiej'

+ 132 - 105
celery/tests/worker/test_worker.py

@@ -9,7 +9,7 @@ from Queue import Empty
 
 
 from billiard.exceptions import WorkerLostError
 from billiard.exceptions import WorkerLostError
 from kombu import Connection
 from kombu import Connection
-from kombu.common import QoS, PREFETCH_COUNT_MAX
+from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
 from kombu.exceptions import StdChannelError
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from kombu.transport.base import Message
 from mock import Mock, patch
 from mock import Mock, patch
@@ -24,41 +24,55 @@ from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker import WorkController
-from celery.worker.components import Queues, Timers, EvLoop, Pool
-from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
+from celery.worker.components import Queues, Timers, Hub, Pool
+from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker.job import Request
+from celery.worker import consumer
 from celery.worker.consumer import Consumer
 from celery.worker.consumer import Consumer
-from celery.worker.consumer import components as consumer_components
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
 from celery.tests.utils import AppCase, Case
 from celery.tests.utils import AppCase, Case
 
 
 
 
+def MockStep(step=None):
+    step = Mock() if step is None else step
+    step.namespace = Mock()
+    step.namespace.name = 'MockNS'
+    step.name = 'MockStep'
+    return step
+
+
 class PlaceHolder(object):
 class PlaceHolder(object):
         pass
         pass
 
 
 
 
-def find_component(obj, typ):
+def find_step(obj, typ):
     for c in obj.namespace.boot_steps:
     for c in obj.namespace.boot_steps:
         if isinstance(c, typ):
         if isinstance(c, typ):
             return c
             return c
-    raise Exception('Instance %s has no %s component' % (obj, typ))
+    raise Exception('Instance %s has no step %s' % (obj, typ))
 
 
 
 
-class MyKombuConsumer(Consumer):
+class _MyKombuConsumer(Consumer):
     broadcast_consumer = Mock()
     broadcast_consumer = Mock()
     task_consumer = Mock()
     task_consumer = Mock()
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         kwargs.setdefault('pool', BasePool(2))
         kwargs.setdefault('pool', BasePool(2))
-        super(MyKombuConsumer, self).__init__(*args, **kwargs)
+        super(_MyKombuConsumer, self).__init__(*args, **kwargs)
 
 
     def restart_heartbeat(self):
     def restart_heartbeat(self):
         self.heart = None
         self.heart = None
 
 
 
 
+class MyKombuConsumer(Consumer):
+
+    def loop(self, *args, **kwargs):
+        pass
+
+
 class MockNode(object):
 class MockNode(object):
     commands = []
     commands = []
 
 
@@ -246,7 +260,7 @@ class test_Consumer(Case):
 
 
         l.namespace.state = RUN
         l.namespace.state = RUN
         l.event_dispatcher = None
         l.event_dispatcher = None
-        l.restart()
+        l.namespace.restart(l)
         self.assertTrue(l.connection)
         self.assertTrue(l.connection)
 
 
         l.namespace.state = RUN
         l.namespace.state = RUN
@@ -256,7 +270,7 @@ class test_Consumer(Case):
 
 
         l.namespace.start(l)
         l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
         self.assertIsInstance(l.connection, Connection)
-        l.restart()
+        l.namespace.restart(l)
 
 
         l.stop()
         l.stop()
         l.shutdown()
         l.shutdown()
@@ -266,9 +280,9 @@ class test_Consumer(Case):
     def test_close_connection(self):
     def test_close_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l.namespace.state = RUN
         l.namespace.state = RUN
-        comp = find_component(l, consumer_components.ConsumerConnection)
+        step = find_step(l, consumer.Connection)
         conn = l.connection = Mock()
         conn = l.connection = Mock()
-        comp.shutdown(l)
+        step.shutdown(l)
         self.assertTrue(conn.close.called)
         self.assertTrue(conn.close.called)
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
 
 
@@ -277,20 +291,21 @@ class test_Consumer(Case):
         eventer.enabled = True
         eventer.enabled = True
         heart = l.heart = MockHeart()
         heart = l.heart = MockHeart()
         l.namespace.state = RUN
         l.namespace.state = RUN
-        Events = find_component(l, consumer_components.Events)
+        Events = find_step(l, consumer.Events)
         Events.shutdown(l)
         Events.shutdown(l)
-        Heart = find_component(l, consumer_components.Heartbeat)
+        Heart = find_step(l, consumer.Heart)
         Heart.shutdown(l)
         Heart.shutdown(l)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
         self.assertTrue(heart.closed)
 
 
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.warn')
     def test_receive_message_unknown(self, warn):
     def test_receive_message_unknown(self, warn):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
         m = create_message(backend, unknown={'baz': '!!!'})
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
+        l.node = MockNode()
 
 
         callback = self._get_on_message(l)
         callback = self._get_on_message(l)
         callback(m.decode(), m)
         callback(m.decode(), m)
@@ -299,13 +314,14 @@ class test_Consumer(Case):
     @patch('celery.worker.consumer.to_timestamp')
     @patch('celery.worker.consumer.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
         to_timestamp.side_effect = OverflowError()
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                                    args=('2, 2'),
                                    args=('2, 2'),
                                    kwargs={},
                                    kwargs={},
                                    eta=datetime.now().isoformat())
                                    eta=datetime.now().isoformat())
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
+        l.node = MockNode()
         l.update_strategies()
         l.update_strategies()
 
 
         callback = self._get_on_message(l)
         callback = self._get_on_message(l)
@@ -315,7 +331,8 @@ class test_Consumer(Case):
 
 
     @patch('celery.worker.consumer.error')
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
     def test_receive_message_InvalidTaskError(self, error):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.update_strategies()
@@ -327,7 +344,7 @@ class test_Consumer(Case):
 
 
     @patch('celery.worker.consumer.crit')
     @patch('celery.worker.consumer.crit')
     def test_on_decode_error(self, crit):
     def test_on_decode_error(self, crit):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
 
 
         class MockMessage(Mock):
         class MockMessage(Mock):
             content_type = 'application/x-msgpack'
             content_type = 'application/x-msgpack'
@@ -352,7 +369,7 @@ class test_Consumer(Case):
         return l.task_consumer.register_callback.call_args[0][0]
         return l.task_consumer.register_callback.call_args[0][0]
 
 
     def test_receieve_message(self):
     def test_receieve_message(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
         l.update_strategies()
@@ -430,7 +447,7 @@ class test_Consumer(Case):
                 self.obj.connection = None
                 self.obj.connection = None
                 raise socket.error('foo')
                 raise socket.error('foo')
 
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         l.namespace.state = RUN
         l.namespace.state = RUN
         c = l.connection = Connection()
         c = l.connection = Connection()
         l.connection.obj = l
         l.connection.obj = l
@@ -451,7 +468,7 @@ class test_Consumer(Case):
             def drain_events(self, **kwargs):
             def drain_events(self, **kwargs):
                 self.obj.connection = None
                 self.obj.connection = None
 
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         l.connection = Connection()
         l.connection = Connection()
         l.connection.obj = l
         l.connection.obj = l
         l.task_consumer = Mock()
         l.task_consumer = Mock()
@@ -469,15 +486,15 @@ class test_Consumer(Case):
         self.assertEqual(l.qos.value, 9)
         self.assertEqual(l.qos.value, 9)
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
 
-    def test_maybe_conn_error(self):
+    def test_ignore_errors(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l.connection_errors = (KeyError, )
         l.connection_errors = (KeyError, )
         l.channel_errors = (SyntaxError, )
         l.channel_errors = (SyntaxError, )
-        l.maybe_conn_error(Mock(side_effect=AttributeError('foo')))
-        l.maybe_conn_error(Mock(side_effect=KeyError('foo')))
-        l.maybe_conn_error(Mock(side_effect=SyntaxError('foo')))
+        ignore_errors(l, Mock(side_effect=AttributeError('foo')))
+        ignore_errors(l, Mock(side_effect=KeyError('foo')))
+        ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
         with self.assertRaises(IndexError):
         with self.assertRaises(IndexError):
-            l.maybe_conn_error(Mock(side_effect=IndexError('foo')))
+            ignore_errors(l, Mock(side_effect=IndexError('foo')))
 
 
     def test_apply_eta_task(self):
     def test_apply_eta_task(self):
         from celery.worker import state
         from celery.worker import state
@@ -492,7 +509,8 @@ class test_Consumer(Case):
         self.assertIs(self.ready_queue.get_nowait(), task)
         self.assertIs(self.ready_queue.get_nowait(), task)
 
 
     def test_receieve_message_eta_isoformat(self):
     def test_receieve_message_eta_isoformat(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            eta=datetime.now().isoformat(),
                            eta=datetime.now().isoformat(),
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
@@ -517,29 +535,30 @@ class test_Consumer(Case):
         self.assertGreater(l.qos.value, current_pcount)
         self.assertGreater(l.qos.value, current_pcount)
         l.timer.stop()
         l.timer.stop()
 
 
-    def test_on_control(self):
+    def test_pidbox_callback(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        con = find_component(l, consumer_components.Controller)
-        con.pidbox_node = Mock()
-        con.reset_pidbox_node = Mock()
+        con = find_step(l, consumer.Control).box
+        con.node = Mock()
+        con.reset = Mock()
 
 
-        con.on_control('foo', 'bar')
-        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
 
 
-        con.pidbox_node = Mock()
-        con.pidbox_node.handle_message.side_effect = KeyError('foo')
-        con.on_control('foo', 'bar')
-        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.node = Mock()
+        con.node.handle_message.side_effect = KeyError('foo')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
 
 
-        con.pidbox_node = Mock()
-        con.pidbox_node.handle_message.side_effect = ValueError('foo')
-        con.on_control('foo', 'bar')
-        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
-        con.reset_pidbox_node.assert_called_with()
+        con.node = Mock()
+        con.node.handle_message.side_effect = ValueError('foo')
+        con.on_message('foo', 'bar')
+        con.node.handle_message.assert_called_with('foo', 'bar')
+        self.assertTrue(con.reset.called)
 
 
     def test_revoke(self):
     def test_revoke(self):
         ready_queue = FastQueue()
         ready_queue = FastQueue()
-        l = MyKombuConsumer(ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         backend = Mock()
         id = uuid()
         id = uuid()
         t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
         t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
@@ -552,7 +571,8 @@ class test_Consumer(Case):
         self.assertTrue(ready_queue.empty())
         self.assertTrue(ready_queue.empty())
 
 
     def test_receieve_message_not_registered(self):
     def test_receieve_message_not_registered(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
 
 
@@ -566,7 +586,7 @@ class test_Consumer(Case):
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.logger')
     @patch('celery.worker.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
     def test_receieve_message_ack_raises(self, logger, warn):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         backend = Mock()
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
 
@@ -584,7 +604,8 @@ class test_Consumer(Case):
         self.assertTrue(logger.critical.call_count)
         self.assertTrue(logger.critical.call_count)
 
 
     def test_receieve_message_eta(self):
     def test_receieve_message_eta(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.event_dispatcher._outbound_buffer = deque()
         l.event_dispatcher._outbound_buffer = deque()
         backend = Mock()
         backend = Mock()
@@ -600,7 +621,7 @@ class test_Consumer(Case):
             l.namespace.start(l)
             l.namespace.start(l)
         finally:
         finally:
             l.app.conf.BROKER_CONNECTION_RETRY = p
             l.app.conf.BROKER_CONNECTION_RETRY = p
-        l.restart()
+        l.namespace.restart(l)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         callback = self._get_on_message(l)
         callback = self._get_on_message(l)
         callback(m.decode(), m)
         callback(m.decode(), m)
@@ -617,27 +638,34 @@ class test_Consumer(Case):
 
 
     def test_reset_pidbox_node(self):
     def test_reset_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        con = find_component(l, consumer_components.Controller)
-        con.pidbox_node = Mock()
-        chan = con.pidbox_node.channel = Mock()
+        con = find_step(l, consumer.Control).box
+        con.node = Mock()
+        chan = con.node.channel = Mock()
         l.connection = Mock()
         l.connection = Mock()
         chan.close.side_effect = socket.error('foo')
         chan.close.side_effect = socket.error('foo')
         l.connection_errors = (socket.error, )
         l.connection_errors = (socket.error, )
-        con.reset_pidbox_node()
+        con.reset()
         chan.close.assert_called_with()
         chan.close.assert_called_with()
 
 
     def test_reset_pidbox_node_green(self):
     def test_reset_pidbox_node_green(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        con = find_component(l, consumer_components.Controller)
-        l.pool = Mock()
-        l.pool.is_green = True
-        con.reset_pidbox_node()
-        l.pool.spawn_n.assert_called_with(con._green_pidbox_node)
+        from celery.worker.pidbox import gPidbox
+        pool = Mock()
+        pool.is_green = True
+        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        con = find_step(l, consumer.Control)
+        self.assertIsInstance(con.box, gPidbox)
+        con.start(l)
+        l.pool.spawn_n.assert_called_with(
+            con.box.loop, l,
+        )
 
 
     def test__green_pidbox_node(self):
     def test__green_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        cont = find_component(l, consumer_components.Controller)
+        pool = Mock()
+        pool.is_green = True
+        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        l.node = Mock()
+        controller = find_step(l, consumer.Control)
+        box = controller.box
 
 
         class BConsumer(Mock):
         class BConsumer(Mock):
 
 
@@ -648,7 +676,7 @@ class test_Consumer(Case):
             def __exit__(self, *exc_info):
             def __exit__(self, *exc_info):
                 self.cancel()
                 self.cancel()
 
 
-        cont.pidbox_node.listen = BConsumer()
+        controller.box.node.listen = BConsumer()
         connections = []
         connections = []
 
 
         class Connection(object):
         class Connection(object):
@@ -677,26 +705,26 @@ class test_Consumer(Case):
                     self.calls += 1
                     self.calls += 1
                     raise socket.timeout()
                     raise socket.timeout()
                 self.obj.connection = None
                 self.obj.connection = None
-                cont._pidbox_node_shutdown.set()
+                controller.box._node_shutdown.set()
 
 
             def close(self):
             def close(self):
                 self.closed = True
                 self.closed = True
 
 
         l.connection = Mock()
         l.connection = Mock()
-        l._open_connection = lambda: Connection(obj=l)
-        controller = find_component(l, consumer_components.Controller)
-        controller._green_pidbox_node()
+        l.connect = lambda: Connection(obj=l)
+        controller = find_step(l, consumer.Control)
+        controller.box.loop(l)
 
 
-        cont.pidbox_node.listen.assert_called_with(callback=cont.on_control)
-        self.assertTrue(cont.broadcast_consumer)
-        cont.broadcast_consumer.consume.assert_called_with()
+        self.assertTrue(controller.box.node.listen.called)
+        self.assertTrue(controller.box.consumer)
+        controller.box.consumer.consume.assert_called_with()
 
 
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
         self.assertTrue(connections[0].closed)
         self.assertTrue(connections[0].closed)
 
 
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.utils.sleep')
     @patch('kombu.utils.sleep')
-    def test_open_connection_errback(self, sleep, connect):
+    def test_connect_errback(self, sleep, connect):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         from kombu.transport.memory import Transport
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
         Transport.connection_errors = (StdChannelError, )
@@ -706,16 +734,16 @@ class test_Consumer(Case):
                 return
                 return
             raise StdChannelError()
             raise StdChannelError()
         connect.side_effect = effect
         connect.side_effect = effect
-        l._open_connection()
+        l.connect()
         connect.assert_called_with()
         connect.assert_called_with()
 
 
     def test_stop_pidbox_node(self):
     def test_stop_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        cont = find_component(l, consumer_components.Controller)
-        cont._pidbox_node_stopped = Event()
-        cont._pidbox_node_shutdown = Event()
-        cont._pidbox_node_stopped.set()
-        cont.stop_pidbox_node()
+        cont = find_step(l, consumer.Control)
+        cont._node_stopped = Event()
+        cont._node_shutdown = Event()
+        cont._node_stopped.set()
+        cont.stop(l)
 
 
     def test_start__loop(self):
     def test_start__loop(self):
 
 
@@ -752,7 +780,6 @@ class test_Consumer(Case):
         l.loop = raises_KeyError
         l.loop = raises_KeyError
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             l.start()
             l.start()
-        self.assertTrue(init_callback.call_count)
         self.assertEqual(l.iterations, 2)
         self.assertEqual(l.iterations, 2)
         self.assertEqual(l.qos.prev, l.qos.value)
         self.assertEqual(l.qos.prev, l.qos.value)
 
 
@@ -766,11 +793,11 @@ class test_Consumer(Case):
         l.loop = Mock(side_effect=socket.error('foo'))
         l.loop = Mock(side_effect=socket.error('foo'))
         with self.assertRaises(socket.error):
         with self.assertRaises(socket.error):
             l.start()
             l.start()
-        self.assertTrue(init_callback.call_count)
         self.assertTrue(l.loop.call_count)
         self.assertTrue(l.loop.call_count)
 
 
     def test_reset_connection_with_no_node(self):
     def test_reset_connection_with_no_node(self):
         l = Consumer(self.ready_queue, timer=self.timer)
         l = Consumer(self.ready_queue, timer=self.timer)
+        l.steps.pop()
         self.assertEqual(None, l.pool)
         self.assertEqual(None, l.pool)
         l.namespace.start(l)
         l.namespace.start(l)
 
 
@@ -817,7 +844,7 @@ class test_WorkController(AppCase):
     def test_use_pidfile(self, create_pidlock):
     def test_use_pidfile(self, create_pidlock):
         create_pidlock.return_value = Mock()
         create_pidlock.return_value = Mock()
         worker = self.create_worker(pidfile='pidfilelockfilepid')
         worker = self.create_worker(pidfile='pidfilelockfilepid')
-        worker.components = []
+        worker.steps = []
         worker.start()
         worker.start()
         self.assertTrue(create_pidlock.called)
         self.assertTrue(create_pidlock.called)
         worker.stop()
         worker.stop()
@@ -864,12 +891,12 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.pool)
         self.assertTrue(worker.pool)
         self.assertTrue(worker.consumer)
         self.assertTrue(worker.consumer)
         self.assertTrue(worker.mediator)
         self.assertTrue(worker.mediator)
-        self.assertTrue(worker.components)
+        self.assertTrue(worker.steps)
 
 
     def test_with_embedded_celerybeat(self):
     def test_with_embedded_celerybeat(self):
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         self.assertTrue(worker.beat)
         self.assertTrue(worker.beat)
-        self.assertIn(worker.beat, [w.obj for w in worker.components])
+        self.assertIn(worker.beat, [w.obj for w in worker.steps])
 
 
     def test_with_autoscaler(self):
     def test_with_autoscaler(self):
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
@@ -931,7 +958,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode())
         task = Request.from_message(m, m.decode())
-        worker.components = []
+        worker.steps = []
         worker.namespace.state = RUN
         worker.namespace.state = RUN
         with self.assertRaises(KeyboardInterrupt):
         with self.assertRaises(KeyboardInterrupt):
             worker.process_task(task)
             worker.process_task(task)
@@ -945,7 +972,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode())
         task = Request.from_message(m, m.decode())
-        worker.components = []
+        worker.steps = []
         worker.namespace.state = RUN
         worker.namespace.state = RUN
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             worker.process_task(task)
             worker.process_task(task)
@@ -964,17 +991,18 @@ class test_WorkController(AppCase):
 
 
     def test_start_catches_base_exceptions(self):
     def test_start_catches_base_exceptions(self):
         worker1 = self.create_worker()
         worker1 = self.create_worker()
-        stc = Mock()
+        stc = MockStep()
         stc.start.side_effect = SystemTerminate()
         stc.start.side_effect = SystemTerminate()
-        worker1.components = [stc]
+        worker1.steps = [stc]
         worker1.start()
         worker1.start()
+        stc.start.assert_called_with(worker1)
         self.assertTrue(stc.terminate.call_count)
         self.assertTrue(stc.terminate.call_count)
 
 
         worker2 = self.create_worker()
         worker2 = self.create_worker()
-        sec = Mock()
+        sec = MockStep()
         sec.start.side_effect = SystemExit()
         sec.start.side_effect = SystemExit()
         sec.terminate = None
         sec.terminate = None
-        worker2.components = [sec]
+        worker2.steps = [sec]
         worker2.start()
         worker2.start()
         self.assertTrue(sec.stop.call_count)
         self.assertTrue(sec.stop.call_count)
 
 
@@ -1027,18 +1055,18 @@ class test_WorkController(AppCase):
     def test_start__stop(self):
     def test_start__stop(self):
         worker = self.worker
         worker = self.worker
         worker.namespace.shutdown_complete.set()
         worker.namespace.shutdown_complete.set()
-        worker.components = [StartStopComponent(self) for _ in range(4)]
+        worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
         worker.namespace.state = RUN
         worker.namespace.state = RUN
         worker.namespace.started = 4
         worker.namespace.started = 4
-        for w in worker.components:
+        for w in worker.steps:
             w.start = Mock()
             w.start = Mock()
             w.stop = Mock()
             w.stop = Mock()
 
 
         worker.start()
         worker.start()
-        for w in worker.components:
+        for w in worker.steps:
             self.assertTrue(w.start.call_count)
             self.assertTrue(w.start.call_count)
         worker.stop()
         worker.stop()
-        for w in worker.components:
+        for w in worker.steps:
             self.assertTrue(w.stop.call_count)
             self.assertTrue(w.stop.call_count)
 
 
         # Doesn't close pool if no pool.
         # Doesn't close pool if no pool.
@@ -1047,15 +1075,15 @@ class test_WorkController(AppCase):
         worker.stop()
         worker.stop()
 
 
         # test that stop of None is not attempted
         # test that stop of None is not attempted
-        worker.components[-1] = None
+        worker.steps[-1] = None
         worker.start()
         worker.start()
         worker.stop()
         worker.stop()
 
 
-    def test_component_raises(self):
+    def test_step_raises(self):
         worker = self.worker
         worker = self.worker
-        comp = Mock()
-        worker.components = [comp]
-        comp.start.side_effect = TypeError()
+        step = Mock()
+        worker.steps = [step]
+        step.start.side_effect = TypeError()
         worker.stop = Mock()
         worker.stop = Mock()
         worker.start()
         worker.start()
         worker.stop.assert_called_with()
         worker.stop.assert_called_with()
@@ -1068,16 +1096,15 @@ class test_WorkController(AppCase):
         worker.namespace.shutdown_complete.set()
         worker.namespace.shutdown_complete.set()
         worker.namespace.started = 5
         worker.namespace.started = 5
         worker.namespace.state = RUN
         worker.namespace.state = RUN
-        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
-
+        worker.steps = [MockStep() for _ in range(5)]
         worker.start()
         worker.start()
-        for w in worker.components[:3]:
+        for w in worker.steps[:3]:
             self.assertTrue(w.start.call_count)
             self.assertTrue(w.start.call_count)
-        self.assertTrue(worker.namespace.started, len(worker.components))
+        self.assertTrue(worker.namespace.started, len(worker.steps))
         self.assertEqual(worker.namespace.state, RUN)
         self.assertEqual(worker.namespace.state, RUN)
         worker.terminate()
         worker.terminate()
-        for component in worker.components:
-            self.assertTrue(component.terminate.call_count)
+        for step in worker.steps:
+            self.assertTrue(step.terminate.call_count)
 
 
     def test_Queues_pool_not_rlimit_safe(self):
     def test_Queues_pool_not_rlimit_safe(self):
         w = Mock()
         w = Mock()
@@ -1091,9 +1118,9 @@ class test_WorkController(AppCase):
         Queues(w).create(w)
         Queues(w).create(w)
         self.assertIs(w.ready_queue.put, w.process_task)
         self.assertIs(w.ready_queue.put, w.process_task)
 
 
-    def test_EvLoop_crate(self):
+    def test_Hub_crate(self):
         w = Mock()
         w = Mock()
-        x = EvLoop(w)
+        x = Hub(w)
         hub = x.create(w)
         hub = x.create(w)
         self.assertTrue(w.timer.max_interval)
         self.assertTrue(w.timer.max_interval)
         self.assertIs(w.hub, hub)
         self.assertIs(w.hub, hub)

+ 0 - 5
celery/utils/text.py

@@ -81,8 +81,3 @@ def pretty(value, width=80, nl_width=80, **kw):
         return '\n{0}{1}'.format(' ' * 4, pformat(value, width=nl_width, **kw))
         return '\n{0}{1}'.format(' ' * 4, pformat(value, width=nl_width, **kw))
     else:
     else:
         return pformat(value, width=width, **kw)
         return pformat(value, width=width, **kw)
-
-
-def dump_body(m, body):
-    return '{0} ({1}b)'.format(truncate(safe_repr(body), 1024),
-                               len(m.body))

+ 11 - 0
celery/utils/threads.py

@@ -9,10 +9,13 @@
 from __future__ import absolute_import, print_function
 from __future__ import absolute_import, print_function
 
 
 import os
 import os
+import socket
 import sys
 import sys
 import threading
 import threading
 import traceback
 import traceback
 
 
+from contextlib import contextmanager
+
 from celery.local import Proxy
 from celery.local import Proxy
 from celery.utils.compat import THREAD_TIMEOUT_MAX
 from celery.utils.compat import THREAD_TIMEOUT_MAX
 
 
@@ -284,6 +287,14 @@ class LocalManager(object):
             self.__class__.__name__, len(self.locals))
             self.__class__.__name__, len(self.locals))
 
 
 
 
+@contextmanager
+def default_socket_timeout(timeout):
+    prev = socket.getdefaulttimeout()
+    socket.setdefaulttimeout(timeout)
+    yield
+    socket.setdefaulttimeout(prev)
+
+
 class _FastLocalStack(threading.local):
 class _FastLocalStack(threading.local):
 
 
     def __init__(self):
     def __init__(self):

+ 23 - 23
celery/worker/__init__.py

@@ -43,24 +43,6 @@ enable the CELERY_CREATE_MISSING_QUEUES setting.
 """
 """
 
 
 
 
-class Namespace(bootsteps.Namespace):
-    """This is the boot-step namespace of the :class:`WorkController`.
-
-    It loads modules from :setting:`CELERYD_BOOT_STEPS`, and its
-    own set of built-in boot-step modules.
-
-    """
-    name = 'worker'
-    builtin_boot_steps = ('celery.worker.components',
-                          'celery.worker.autoscale',
-                          'celery.worker.autoreload',
-                          'celery.worker.consumer',
-                          'celery.worker.mediator')
-
-    def modules(self):
-        return self.builtin_boot_steps + self.app.conf.CELERYD_BOOT_STEPS
-
-
 class WorkController(configurated):
 class WorkController(configurated):
     """Unmanaged worker instance."""
     """Unmanaged worker instance."""
     app = None
     app = None
@@ -90,6 +72,24 @@ class WorkController(configurated):
 
 
     pidlock = None
     pidlock = None
 
 
+    class Namespace(bootsteps.Namespace):
+        """This is the boot-step namespace of the :class:`WorkController`.
+
+        It loads modules from :setting:`CELERYD_BOOT_STEPS`, and its
+        own set of built-in boot-step modules.
+
+        """
+        name = 'worker'
+        builtin_boot_steps = (
+            'celery.worker.components',
+            'celery.worker.autoscale',
+            'celery.worker.autoreload',
+            'celery.worker.mediator',
+        )
+
+        def modules(self):
+            return self.builtin_boot_steps + self.app.conf.CELERYD_BOOT_STEPS
+
     def __init__(self, app=None, hostname=None, **kwargs):
     def __init__(self, app=None, hostname=None, **kwargs):
         self.app = app_or_default(app or self.app)
         self.app = app_or_default(app or self.app)
         self.hostname = hostname or socket.gethostname()
         self.hostname = hostname or socket.gethostname()
@@ -122,12 +122,12 @@ class WorkController(configurated):
 
 
         # Initialize boot steps
         # Initialize boot steps
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
-        self.components = []
+        self.steps = []
         self.on_init_namespace()
         self.on_init_namespace()
-        self.namespace = Namespace(app=self.app,
-                                   on_start=self.on_start,
-                                   on_close=self.on_close,
-                                   on_stopped=self.on_stopped)
+        self.namespace = self.Namespace(app=self.app,
+                                        on_start=self.on_start,
+                                        on_close=self.on_close,
+                                        on_stopped=self.on_stopped)
         self.namespace.apply(self, **kwargs)
         self.namespace.apply(self, **kwargs)
 
 
     def on_init_namespace(self):
     def on_init_namespace(self):

+ 2 - 2
celery/worker/autoreload.py

@@ -23,7 +23,7 @@ from celery.utils.imports import module_file
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
 
 
-from .bootsteps import StartStopComponent
+from . import bootsteps
 
 
 try:                        # pragma: no cover
 try:                        # pragma: no cover
     import pyinotify
     import pyinotify
@@ -35,7 +35,7 @@ except ImportError:         # pragma: no cover
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-class WorkerComponent(StartStopComponent):
+class WorkerComponent(bootsteps.StartStopStep):
     name = 'worker.autoreloader'
     name = 'worker.autoreloader'
     requires = ('pool', )
     requires = ('pool', )
 
 

+ 2 - 2
celery/worker/autoscale.py

@@ -21,15 +21,15 @@ from time import sleep, time
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
 
 
+from . import bootsteps
 from . import state
 from . import state
-from .bootsteps import StartStopComponent
 from .hub import DummyLock
 from .hub import DummyLock
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 debug, info, error = logger.debug, logger.info, logger.error
 debug, info, error = logger.debug, logger.info, logger.error
 
 
 
 
-class WorkerComponent(StartStopComponent):
+class WorkerComponent(bootsteps.StartStopStep):
     name = 'worker.autoscaler'
     name = 'worker.autoscaler'
     requires = ('pool', )
     requires = ('pool', )
 
 

+ 70 - 70
celery/worker/bootsteps.py

@@ -3,20 +3,19 @@
     celery.worker.bootsteps
     celery.worker.bootsteps
     ~~~~~~~~~~~~~~~~~~~~~~~
     ~~~~~~~~~~~~~~~~~~~~~~~
 
 
-    The boot-step components.
+    The boot-steps!
 
 
 """
 """
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-import socket
-
 from collections import defaultdict
 from collections import defaultdict
 from importlib import import_module
 from importlib import import_module
 from threading import Event
 from threading import Event
 
 
 from celery.datastructures import DependencyGraph
 from celery.datastructures import DependencyGraph
-from celery.utils.imports import instantiate, qualname
+from celery.utils.imports import instantiate
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
+from celery.utils.threads import default_socket_timeout
 
 
 try:
 try:
     from greenlet import GreenletExit
     from greenlet import GreenletExit
@@ -35,13 +34,17 @@ TERMINATE = 0x3
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
+def qualname(c):
+    return '.'.join([c.namespace.name, c.name.capitalize()])
+
+
 class Namespace(object):
 class Namespace(object):
-    """A namespace containing components.
+    """A namespace containing bootsteps.
 
 
-    Every component must belong to a namespace.
+    Every step must belong to a namespace.
 
 
-    When component classes are created they are added to the
-    mapping of unclaimed components.  The components will be
+    When step classes are created they are added to the
+    mapping of unclaimed steps.  The steps will be
     claimed when the namespace they belong to is created.
     claimed when the namespace they belong to is created.
 
 
     :keyword name: Set the name of this namespace.
     :keyword name: Set the name of this namespace.
@@ -68,35 +71,30 @@ class Namespace(object):
         self.state = RUN
         self.state = RUN
         if self.on_start:
         if self.on_start:
             self.on_start()
             self.on_start()
-        for i, component in enumerate(parent.components):
-            if component:
-                logger.debug('Starting %s...', qualname(component))
+        for i, step in enumerate(parent.steps):
+            if step:
+                logger.debug('Starting %s...', qualname(step))
                 self.started = i + 1
                 self.started = i + 1
-                component.start(parent)
-                logger.debug('%s OK!', qualname(component))
+                print('STARTING: %r' % (step.start, ))
+                step.start(parent)
+                logger.debug('%s OK!', qualname(step))
 
 
     def close(self, parent):
     def close(self, parent):
         if self.on_close:
         if self.on_close:
             self.on_close()
             self.on_close()
-        for component in parent.components:
-            try:
-                close = component.close
-            except AttributeError:
-                pass
-            else:
+        for step in parent.steps:
+            close = getattr(step, 'close', None)
+            if close:
                 close(parent)
                 close(parent)
 
 
-    def restart(self, parent, description='Restarting', terminate=False):
-        socket_timeout = socket.getdefaulttimeout()
-        socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
-        try:
-            for component in reversed(parent.components):
-                if component:
-                    logger.debug('%s %s...', description, qualname(component))
-                    (component.terminate if terminate
-                        else component.stop)(parent)
-        finally:
-            socket.setdefaulttimeout(socket_timeout)
+    def restart(self, parent, description='Restarting', attr='stop'):
+        with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
+            for step in reversed(parent.steps):
+                if step:
+                    logger.debug('%s %s...', description, qualname(step))
+                    fun = getattr(step, attr, None)
+                    if fun:
+                        fun(parent)
 
 
     def stop(self, parent, close=True, terminate=False):
     def stop(self, parent, close=True, terminate=False):
         what = 'Terminating' if terminate else 'Stopping'
         what = 'Terminating' if terminate else 'Stopping'
@@ -105,13 +103,13 @@ class Namespace(object):
 
 
         self.close(parent)
         self.close(parent)
 
 
-        if self.state != RUN or self.started != len(parent.components):
+        if self.state != RUN or self.started != len(parent.steps):
             # Not fully started, can safely exit.
             # Not fully started, can safely exit.
             self.state = TERMINATE
             self.state = TERMINATE
             self.shutdown_complete.set()
             self.shutdown_complete.set()
             return
             return
         self.state = CLOSE
         self.state = CLOSE
-        self.restart(parent, what, terminate)
+        self.restart(parent, what, 'terminate' if terminate else 'stop')
 
 
         if self.on_stopped:
         if self.on_stopped:
             self.on_stopped()
             self.on_stopped()
@@ -128,40 +126,40 @@ class Namespace(object):
 
 
     def modules(self):
     def modules(self):
         """Subclasses can override this to return a
         """Subclasses can override this to return a
-        list of modules to import before components are claimed."""
+        list of modules to import before steps are claimed."""
         return []
         return []
 
 
     def load_modules(self):
     def load_modules(self):
-        """Will load the component modules this namespace depends on."""
+        """Will load the steps modules this namespace depends on."""
         for m in self.modules():
         for m in self.modules():
             self.import_module(m)
             self.import_module(m)
 
 
     def apply(self, parent, **kwargs):
     def apply(self, parent, **kwargs):
-        """Apply the components in this namespace to an object.
+        """Apply the steps in this namespace to an object.
 
 
         This will apply the ``__init__`` and ``include`` methods
         This will apply the ``__init__`` and ``include`` methods
-        of each components with the object as argument.
+        of each steps with the object as argument.
 
 
-        For ``StartStopComponents`` the services created
-        will also be added the the objects ``components`` attribute.
+        For :class:`StartStopStep` the services created
+        will also be added the the objects ``steps`` attribute.
 
 
         """
         """
         self._debug('Loading modules.')
         self._debug('Loading modules.')
         self.load_modules()
         self.load_modules()
-        self._debug('Claiming components.')
-        self.components = self._claim()
+        self._debug('Claiming steps.')
+        self.steps = self._claim()
         self._debug('Building boot step graph.')
         self._debug('Building boot step graph.')
-        self.boot_steps = [self.bind_component(name, parent, **kwargs)
+        self.boot_steps = [self.bind_step(name, parent, **kwargs)
                                 for name in self._finalize_boot_steps()]
                                 for name in self._finalize_boot_steps()]
         self._debug('New boot order: {%s}',
         self._debug('New boot order: {%s}',
                 ', '.join(c.name for c in self.boot_steps))
                 ', '.join(c.name for c in self.boot_steps))
 
 
-        for component in self.boot_steps:
-            component.include(parent)
+        for step in self.boot_steps:
+            step.include(parent)
         return self
         return self
 
 
-    def bind_component(self, name, parent, **kwargs):
-        """Bind component to parent object and this namespace."""
+    def bind_step(self, name, parent, **kwargs):
+        """Bind step to parent object and this namespace."""
         comp = self[name](parent, **kwargs)
         comp = self[name](parent, **kwargs)
         comp.namespace = self
         comp.namespace = self
         return comp
         return comp
@@ -170,16 +168,16 @@ class Namespace(object):
         return import_module(module)
         return import_module(module)
 
 
     def __getitem__(self, name):
     def __getitem__(self, name):
-        return self.components[name]
+        return self.steps[name]
 
 
     def _find_last(self):
     def _find_last(self):
-        for C in self.components.itervalues():
+        for C in self.steps.itervalues():
             if C.last:
             if C.last:
                 return C
                 return C
 
 
     def _finalize_boot_steps(self):
     def _finalize_boot_steps(self):
         G = self.graph = DependencyGraph((C.name, C.requires)
         G = self.graph = DependencyGraph((C.name, C.requires)
-                            for C in self.components.itervalues())
+                            for C in self.steps.itervalues())
         last = self._find_last()
         last = self._find_last()
         if last:
         if last:
             for obj in G:
             for obj in G:
@@ -201,8 +199,8 @@ def _prepare_requires(req):
     return req
     return req
 
 
 
 
-class ComponentType(type):
-    """Metaclass for components."""
+class StepType(type):
+    """Metaclass for steps."""
 
 
     def __new__(cls, name, bases, attrs):
     def __new__(cls, name, bases, attrs):
         abstract = attrs.pop('abstract', False)
         abstract = attrs.pop('abstract', False)
@@ -210,47 +208,47 @@ class ComponentType(type):
             try:
             try:
                 cname = attrs['name']
                 cname = attrs['name']
             except KeyError:
             except KeyError:
-                raise NotImplementedError('Components must be named')
+                raise NotImplementedError('Steps must be named')
             namespace = attrs.get('namespace', None)
             namespace = attrs.get('namespace', None)
             if not namespace:
             if not namespace:
                 attrs['namespace'], _, attrs['name'] = cname.partition('.')
                 attrs['namespace'], _, attrs['name'] = cname.partition('.')
         attrs['requires'] = tuple(_prepare_requires(req)
         attrs['requires'] = tuple(_prepare_requires(req)
                                     for req in attrs.get('requires', ()))
                                     for req in attrs.get('requires', ()))
-        cls = super(ComponentType, cls).__new__(cls, name, bases, attrs)
+        cls = super(StepType, cls).__new__(cls, name, bases, attrs)
         if not abstract:
         if not abstract:
             Namespace._unclaimed[cls.namespace][cls.name] = cls
             Namespace._unclaimed[cls.namespace][cls.name] = cls
         return cls
         return cls
 
 
 
 
-class Component(object):
-    """A component.
+class Step(object):
+    """A Bootstep.
 
 
-    The :meth:`__init__` method is called when the component
+    The :meth:`__init__` method is called when the step
     is bound to a parent object, and can as such be used
     is bound to a parent object, and can as such be used
     to initialize attributes in the parent object at
     to initialize attributes in the parent object at
     parent instantiation-time.
     parent instantiation-time.
 
 
     """
     """
-    __metaclass__ = ComponentType
+    __metaclass__ = StepType
 
 
-    #: The name of the component, or the namespace
-    #: and the name of the component separated by dot.
+    #: The name of the step, or the namespace
+    #: and the name of the step separated by dot.
     name = None
     name = None
 
 
-    #: List of component names this component depends on.
-    #: Note that the dependencies must be in the same namespace.
+    #: List of other steps that that must be started before this step.
+    #: Note that all dependencies must be in the same namespace.
     requires = ()
     requires = ()
 
 
     #: can be used to specify the namespace,
     #: can be used to specify the namespace,
     #: if the name does not include it.
     #: if the name does not include it.
     namespace = None
     namespace = None
 
 
-    #: if set the component will not be registered,
-    #: but can be used as a component base class.
+    #: if set the step will not be registered,
+    #: but can be used as a base class.
     abstract = True
     abstract = True
 
 
     #: Optional obj created by the :meth:`create` method.
     #: Optional obj created by the :meth:`create` method.
-    #: This is used by StartStopComponents to keep the
+    #: This is used by :class:`StartStopStep` to keep the
     #: original service object.
     #: original service object.
     obj = None
     obj = None
 
 
@@ -267,12 +265,12 @@ class Component(object):
         pass
         pass
 
 
     def create(self, parent):
     def create(self, parent):
-        """Create the component."""
+        """Create the step."""
         pass
         pass
 
 
     def include_if(self, parent):
     def include_if(self, parent):
         """An optional predicate that decided whether this
         """An optional predicate that decided whether this
-        component should be created."""
+        step should be created."""
         return self.enabled
         return self.enabled
 
 
     def instantiate(self, qualname, *args, **kwargs):
     def instantiate(self, qualname, *args, **kwargs):
@@ -284,14 +282,16 @@ class Component(object):
             return True
             return True
 
 
 
 
-class StartStopComponent(Component):
+class StartStopStep(Step):
     abstract = True
     abstract = True
 
 
     def start(self, parent):
     def start(self, parent):
-        return self.obj.start()
+        if self.obj:
+            return self.obj.start()
 
 
     def stop(self, parent):
     def stop(self, parent):
-        return self.obj.stop()
+        if self.obj:
+            return self.obj.stop()
 
 
     def close(self, parent):
     def close(self, parent):
         pass
         pass
@@ -300,5 +300,5 @@ class StartStopComponent(Component):
         self.stop(parent)
         self.stop(parent)
 
 
     def include(self, parent):
     def include(self, parent):
-        if super(StartStopComponent, self).include(parent):
-            parent.components.append(self)
+        if super(StartStopStep, self).include(parent):
+            parent.steps.append(self)

+ 73 - 54
celery/worker/components.py

@@ -18,13 +18,53 @@ from billiard.exceptions import WorkerLostError
 from celery.utils.log import worker_logger as logger
 from celery.utils.log import worker_logger as logger
 from celery.utils.timer2 import Schedule
 from celery.utils.timer2 import Schedule
 
 
-from . import bootsteps
+from . import bootsteps, hub
 from .buckets import TaskBucket, FastQueue
 from .buckets import TaskBucket, FastQueue
-from .hub import Hub, BoundedSemaphore
 
 
 
 
-class Pool(bootsteps.StartStopComponent):
-    """The pool component.
+class Hub(bootsteps.StartStopStep):
+    name = 'worker.hub'
+
+    def __init__(self, w, **kwargs):
+        w.hub = None
+
+    def include_if(self, w):
+        return w.use_eventloop
+
+    def create(self, w):
+        w.timer = Schedule(max_interval=10)
+        w.hub = hub.Hub(w.timer)
+        return w.hub
+
+
+class Queues(bootsteps.Step):
+    """This step initializes the internal queues
+    used by the worker."""
+    name = 'worker.queues'
+    requires = (Hub, )
+
+    def create(self, w):
+        w.start_mediator = True
+        if not w.pool_cls.rlimit_safe:
+            w.disable_rate_limits = True
+        if w.disable_rate_limits:
+            w.ready_queue = FastQueue()
+            if w.use_eventloop:
+                w.start_mediator = False
+                if w.pool_putlocks and w.pool_cls.uses_semaphore:
+                    w.ready_queue.put = w.process_task_sem
+                else:
+                    w.ready_queue.put = w.process_task
+            elif not w.pool_cls.requires_mediator:
+                # just send task directly to pool, skip the mediator.
+                w.ready_queue.put = w.process_task
+                w.start_mediator = False
+        else:
+            w.ready_queue = TaskBucket(task_registry=w.app.tasks)
+
+
+class Pool(bootsteps.StartStopStep):
+    """The pool step.
 
 
     Describes how to initialize the worker pool, and starts and stops
     Describes how to initialize the worker pool, and starts and stops
     the pool during worker startup/shutdown.
     the pool during worker startup/shutdown.
@@ -38,7 +78,7 @@ class Pool(bootsteps.StartStopComponent):
 
 
     """
     """
     name = 'worker.pool'
     name = 'worker.pool'
-    requires = ('queues', )
+    requires = (Queues, )
 
 
     def __init__(self, w, autoscale=None, autoreload=None,
     def __init__(self, w, autoscale=None, autoreload=None,
             no_execv=False, **kwargs):
             no_execv=False, **kwargs):
@@ -115,7 +155,7 @@ class Pool(bootsteps.StartStopComponent):
         procs = w.min_concurrency
         procs = w.min_concurrency
         forking_enable = not threaded or (w.no_execv or not w.force_execv)
         forking_enable = not threaded or (w.no_execv or not w.force_execv)
         if not threaded:
         if not threaded:
-            semaphore = w.semaphore = BoundedSemaphore(procs)
+            semaphore = w.semaphore = hub.BoundedSemaphore(procs)
             w._quick_acquire = w.semaphore.acquire
             w._quick_acquire = w.semaphore.acquire
             w._quick_release = w.semaphore.release
             w._quick_release = w.semaphore.release
             max_restarts = 100
             max_restarts = 100
@@ -137,8 +177,8 @@ class Pool(bootsteps.StartStopComponent):
         return pool
         return pool
 
 
 
 
-class Beat(bootsteps.StartStopComponent):
-    """Component used to embed a celerybeat process.
+class Beat(bootsteps.StartStopStep):
+    """Step used to embed a celerybeat process.
 
 
     This will only be enabled if the ``beat``
     This will only be enabled if the ``beat``
     argument is set.
     argument is set.
@@ -158,51 +198,10 @@ class Beat(bootsteps.StartStopComponent):
         return b
         return b
 
 
 
 
-class Queues(bootsteps.Component):
-    """This component initializes the internal queues
-    used by the worker."""
-    name = 'worker.queues'
-    requires = ('ev', )
-
-    def create(self, w):
-        w.start_mediator = True
-        if not w.pool_cls.rlimit_safe:
-            w.disable_rate_limits = True
-        if w.disable_rate_limits:
-            w.ready_queue = FastQueue()
-            if w.use_eventloop:
-                w.start_mediator = False
-                if w.pool_putlocks and w.pool_cls.uses_semaphore:
-                    w.ready_queue.put = w.process_task_sem
-                else:
-                    w.ready_queue.put = w.process_task
-            elif not w.pool_cls.requires_mediator:
-                # just send task directly to pool, skip the mediator.
-                w.ready_queue.put = w.process_task
-                w.start_mediator = False
-        else:
-            w.ready_queue = TaskBucket(task_registry=w.app.tasks)
-
-
-class EvLoop(bootsteps.StartStopComponent):
-    name = 'worker.ev'
-
-    def __init__(self, w, **kwargs):
-        w.hub = None
-
-    def include_if(self, w):
-        return w.use_eventloop
-
-    def create(self, w):
-        w.timer = Schedule(max_interval=10)
-        hub = w.hub = Hub(w.timer)
-        return hub
-
-
-class Timers(bootsteps.Component):
-    """This component initializes the internal timers used by the worker."""
+class Timers(bootsteps.Step):
+    """This step initializes the internal timers used by the worker."""
     name = 'worker.timers'
     name = 'worker.timers'
-    requires = ('pool', )
+    requires = (Pool, )
 
 
     def include_if(self, w):
     def include_if(self, w):
         return not w.use_eventloop
         return not w.use_eventloop
@@ -224,8 +223,8 @@ class Timers(bootsteps.Component):
         logger.debug('Timer wake-up! Next eta %s secs.', delay)
         logger.debug('Timer wake-up! Next eta %s secs.', delay)
 
 
 
 
-class StateDB(bootsteps.Component):
-    """This component sets up the workers state db if enabled."""
+class StateDB(bootsteps.Step):
+    """This step sets up the workers state db if enabled."""
     name = 'worker.state-db'
     name = 'worker.state-db'
 
 
     def __init__(self, w, **kwargs):
     def __init__(self, w, **kwargs):
@@ -235,3 +234,23 @@ class StateDB(bootsteps.Component):
     def create(self, w):
     def create(self, w):
         w._persistence = w.state.Persistent(w.state_db)
         w._persistence = w.state.Persistent(w.state_db)
         atexit.register(w._persistence.save)
         atexit.register(w._persistence.save)
+
+
+class Consumer(bootsteps.StartStopStep):
+    name = 'worker.consumer'
+    last = True
+
+    def create(self, w):
+        prefetch_count = w.concurrency * w.prefetch_multiplier
+        c = w.consumer = self.instantiate(w.consumer_cls,
+                w.ready_queue,
+                hostname=w.hostname,
+                send_events=w.send_events,
+                init_callback=w.ready_callback,
+                initial_prefetch_count=prefetch_count,
+                pool=w.pool,
+                timer=w.timer,
+                app=w.app,
+                controller=w,
+                hub=w.hub)
+        return c

+ 169 - 131
celery/worker/consumer/__init__.py → celery/worker/consumer.py

@@ -3,7 +3,7 @@
 celery.worker.consumer
 celery.worker.consumer
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 
-This module contains the component responsible for consuming messages
+This module contains the components responsible for consuming messages
 from the broker, processing the messages and keeping the broker connections
 from the broker, processing the messages and keeping the broker connections
 up and running.
 up and running.
 
 
@@ -13,6 +13,7 @@ from __future__ import absolute_import
 import logging
 import logging
 import socket
 import socket
 
 
+from kombu.common import QoS, ignore_errors
 from kombu.syn import _detect_environment
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
 
 
@@ -20,20 +21,17 @@ from celery.app import app_or_default
 from celery.task.trace import build_tracer
 from celery.task.trace import build_tracer
 from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.functional import noop
 from celery.utils.functional import noop
-from celery.utils.imports import qualname
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
-from celery.utils.text import dump_body
+from celery.utils.text import truncate
 from celery.utils.timeutils import humanize_seconds, timezone
 from celery.utils.timeutils import humanize_seconds, timezone
-from celery.worker import state
-from celery.worker.state import maybe_shutdown
-from celery.worker.bootsteps import Namespace as _NS, StartStopComponent, CLOSE
 
 
-from . import loops
+from . import bootsteps, heartbeat, loops, pidbox
+from .state import task_reserved, maybe_shutdown
 
 
+CLOSE = bootsteps.CLOSE
 logger = get_logger(__name__)
 logger = get_logger(__name__)
-info, warn, error, crit = (logger.info, logger.warn,
-                           logger.error, logger.critical)
-task_reserved = state.task_reserved
+debug, info, warn, error, crit = (logger.debug, logger.info, logger.warn,
+                                  logger.error, logger.critical)
 
 
 CONNECTION_RETRY = """\
 CONNECTION_RETRY = """\
 consumer: Connection to broker lost. \
 consumer: Connection to broker lost. \
@@ -89,74 +87,20 @@ body: {0} {{content_type:{1} content_encoding:{2} delivery_info:{3}}}\
 """
 """
 
 
 
 
-def debug(msg, *args, **kwargs):
-    logger.debug('consumer: {0}'.format(msg), *args, **kwargs)
-
-
-class Component(StartStopComponent):
-    name = 'worker.consumer'
-    last = True
-
-    def create(self, w):
-        prefetch_count = w.concurrency * w.prefetch_multiplier
-        c = w.consumer = self.instantiate(w.consumer_cls,
-                w.ready_queue,
-                hostname=w.hostname,
-                send_events=w.send_events,
-                init_callback=w.ready_callback,
-                initial_prefetch_count=prefetch_count,
-                pool=w.pool,
-                timer=w.timer,
-                app=w.app,
-                controller=w,
-                hub=w.hub)
-        return c
-
-
-class Namespace(_NS):
-    name = 'consumer'
-    builtin_boot_steps = ('celery.worker.consumer.components', )
-
-    def shutdown(self, parent):
-        delayed = self._shutdown_step(parent, parent.components, force=False)
-        self._shutdown_step(parent, delayed, force=True)
-
-    def _shutdown_step(self, parent, components, force=False):
-        delayed = []
-        for component in components:
-            if component:
-                logger.debug('Shutdown %s...', qualname(component))
-                if not force and getattr(component, 'delay_shutdown', False):
-                    delayed.append(component)
-                else:
-                    component.shutdown(parent)
-        return delayed
-
-    def modules(self):
-        return (self.builtin_boot_steps +
-                self.app.conf.CELERYD_CONSUMER_BOOT_STEPS)
+def dump_body(m, body):
+    return '{0} ({1}b)'.format(truncate(safe_repr(body), 1024),
+                               len(m.body))
 
 
 
 
 class Consumer(object):
 class Consumer(object):
-    """Listen for messages received from the broker and
-    move them to the ready queue for task processing.
-
-    :param ready_queue: See :attr:`ready_queue`.
-    :param timer: See :attr:`timer`.
 
 
-    """
-
-    #: The queue that holds tasks ready for immediate processing.
+    #: Intra-queue for tasks ready to be handled
     ready_queue = None
     ready_queue = None
 
 
-    #: Optional callback to be called when the connection is established.
-    #: Will only be called once, even if the connection is lost and
-    #: re-established.
+    #: Optional callback called the first time the worker
+    #: is ready to receive tasks.
     init_callback = None
     init_callback = None
 
 
-    #: The current hostname.  Defaults to the system hostname.
-    hostname = None
-
     #: The current worker pool instance.
     #: The current worker pool instance.
     pool = None
     pool = None
 
 
@@ -164,6 +108,15 @@ class Consumer(object):
     #: as sending heartbeats.
     #: as sending heartbeats.
     timer = None
     timer = None
 
 
+    class Namespace(bootsteps.Namespace):
+        name = 'consumer'
+
+        def shutdown(self, parent):
+            self.restart(parent, 'Shutdown', 'shutdown')
+
+        def modules(self):
+            return self.app.conf.CELERYD_CONSUMER_BOOT_STEPS
+
     def __init__(self, ready_queue,
     def __init__(self, ready_queue,
             init_callback=noop, hostname=None,
             init_callback=noop, hostname=None,
             pool=None, app=None,
             pool=None, app=None,
@@ -182,52 +135,58 @@ class Consumer(object):
         self.channel_errors = conninfo.channel_errors
         self.channel_errors = conninfo.channel_errors
 
 
         self._does_info = logger.isEnabledFor(logging.INFO)
         self._does_info = logger.isEnabledFor(logging.INFO)
-        if not hasattr(self, 'loop'):
-            self.loop = loops.asynloop if hub else loops.synloop
-        if hub:
-            hub.on_init.append(self.on_poll_init)
-        self.hub = hub
         self._quick_put = self.ready_queue.put
         self._quick_put = self.ready_queue.put
-        self.amqheartbeat = amqheartbeat
-        if self.amqheartbeat is None:
-            self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
-        if not hub:
+
+        if hub:
+            self.amqheartbeat = amqheartbeat
+            if self.amqheartbeat is None:
+                self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
+            self.hub = hub
+            self.hub.on_init.append(self.on_poll_init)
+        else:
+            self.hub = None
             self.amqheartbeat = 0
             self.amqheartbeat = 0
 
 
+        if not hasattr(self, 'loop'):
+            self.loop = loops.asynloop if hub else loops.synloop
+
         if _detect_environment() == 'gevent':
         if _detect_environment() == 'gevent':
             # there's a gevent bug that causes timeouts to not be reset,
             # there's a gevent bug that causes timeouts to not be reset,
             # so if the connection timeout is exceeded once, it can NEVER
             # so if the connection timeout is exceeded once, it can NEVER
             # connect again.
             # connect again.
             self.app.conf.BROKER_CONNECTION_TIMEOUT = None
             self.app.conf.BROKER_CONNECTION_TIMEOUT = None
 
 
-        self.components = []
-        self.namespace = Namespace(app=self.app,
-                                   on_start=self.on_start,
-                                   on_close=self.on_close)
+        self.steps = []
+        self.namespace = self.Namespace(
+            app=self.app, on_start=self.on_start, on_close=self.on_close,
+        )
         self.namespace.apply(self, **kwargs)
         self.namespace.apply(self, **kwargs)
 
 
-    def on_start(self):
-        # reload all task's execution strategies.
-        self.update_strategies()
-        self.init_callback(self)
-
     def start(self):
     def start(self):
-        """Start the consumer.
-
-        Automatically survives intermittent connection failure,
-        and will retry establishing the connection and restart
-        consuming messages.
-
-        """
-        ns, loop, loop_args = self.namespace, self.loop, self.loop_args()
+        ns, loop = self.namespace, self.loop
         while ns.state != CLOSE:
         while ns.state != CLOSE:
             maybe_shutdown()
             maybe_shutdown()
             try:
             try:
                 ns.start(self)
                 ns.start(self)
-                loop(*loop_args)
             except self.connection_errors + self.channel_errors:
             except self.connection_errors + self.channel_errors:
-                error(CONNECTION_RETRY, exc_info=True)
-                self.restart()
+                maybe_shutdown()
+                if ns.state != CLOSE and self.connection:
+                    error(CONNECTION_RETRY, exc_info=True)
+                    ns.restart(self)
+
+    def shutdown(self):
+        self.namespace.shutdown(self)
+
+    def stop(self):
+        self.namespace.stop(self)
+
+    def on_start(self):
+        self.update_strategies()
+
+    def on_ready(self):
+        callback, self.init_callback = self.init_callback, None
+        if callback:
+            callback(self)
 
 
     def loop_args(self):
     def loop_args(self):
         return (self, self.connection, self.task_consumer,
         return (self, self.connection, self.task_consumer,
@@ -235,26 +194,10 @@ class Consumer(object):
                 self.amqheartbeat, self.handle_unknown_message,
                 self.amqheartbeat, self.handle_unknown_message,
                 self.handle_unknown_task, self.handle_invalid_task)
                 self.handle_unknown_task, self.handle_invalid_task)
 
 
-    def restart(self):
-        return self.namespace.restart(self)
-
     def on_poll_init(self, hub):
     def on_poll_init(self, hub):
         hub.update_readers(self.connection.eventmap)
         hub.update_readers(self.connection.eventmap)
         self.connection.transport.on_poll_init(hub.poller)
         self.connection.transport.on_poll_init(hub.poller)
 
 
-    def maybe_conn_error(self, fun):
-        """Applies function but ignores any connection or channel
-        errors raised."""
-        try:
-            fun()
-        except (AttributeError, ) + \
-                self.connection_errors + \
-                self.channel_errors:
-            pass
-
-    def shutdown(self):
-        self.namespace.shutdown(self)
-
     def on_decode_error(self, message, exc):
     def on_decode_error(self, message, exc):
         """Callback called if an error occurs while decoding
         """Callback called if an error occurs while decoding
         a message received.
         a message received.
@@ -278,7 +221,7 @@ class Consumer(object):
         self.ready_queue.clear()
         self.ready_queue.clear()
         self.timer.clear()
         self.timer.clear()
 
 
-    def _open_connection(self):
+    def connect(self):
         """Establish the broker connection.
         """Establish the broker connection.
 
 
         Will retry establishing the connection if the
         Will retry establishing the connection if the
@@ -306,15 +249,6 @@ class Consumer(object):
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
                     callback=maybe_shutdown)
                     callback=maybe_shutdown)
 
 
-    def stop(self):
-        """Stop consuming.
-
-        Does not close the broker connection, so be sure to call
-        :meth:`close_connection` when you are finished with it.
-
-        """
-        self.namespace.stop(self)
-
     def add_task_queue(self, queue, exchange=None, exchange_type=None,
     def add_task_queue(self, queue, exchange=None, exchange_type=None,
             routing_key=None, **options):
             routing_key=None, **options):
         cset = self.task_consumer
         cset = self.task_consumer
@@ -416,10 +350,114 @@ class Consumer(object):
         message.reject_log_error(logger, self.connection_errors)
         message.reject_log_error(logger, self.connection_errors)
 
 
     def update_strategies(self):
     def update_strategies(self):
-        S = self.strategies
-        app = self.app
-        loader = app.loader
-        hostname = self.hostname
+        loader = self.app.loader
         for name, task in self.app.tasks.iteritems():
         for name, task in self.app.tasks.iteritems():
-            S[name] = task.start_strategy(app, self)
-            task.__trace__ = build_tracer(name, task, loader, hostname)
+            self.strategies[name] = task.start_strategy(self.app, self)
+            task.__trace__ = build_tracer(name, task, loader, self.hostname)
+
+
+class Connection(bootsteps.StartStopStep):
+    name = 'consumer.connection'
+
+    def __init__(self, c, **kwargs):
+        c.connection = None
+
+    def start(self, c):
+        c.connection = c.connect()
+        info('Connected to %s', c.connection.as_uri())
+
+    def shutdown(self, c):
+        # We must set self.connection to None here, so
+        # that the green pidbox thread exits.
+        connection, c.connection = c.connection, None
+        if connection:
+            ignore_errors(connection, connection.close)
+
+
+class Events(bootsteps.StartStopStep):
+    name = 'consumer.events'
+    requires = (Connection, )
+
+    def __init__(self, c, send_events=None, **kwargs):
+        self.send_events = send_events
+        c.event_dispatcher = None
+
+    def start(self, c):
+        # Flush events sent while connection was down.
+        prev = c.event_dispatcher
+        dis = c.event_dispatcher = c.app.events.Dispatcher(
+            c.connection, hostname=c.hostname, enabled=self.send_events,
+        )
+        if prev:
+            dis.copy_buffer(prev)
+            dis.flush()
+
+    def stop(self, c):
+        if c.event_dispatcher:
+            ignore_errors(c, c.event_dispatcher.close)
+            c.event_dispatcher = None
+    shutdown = stop
+
+
+class Heart(bootsteps.StartStopStep):
+    name = 'consumer.heart'
+    requires = (Events, )
+
+    def __init__(self, c, **kwargs):
+        c.heart = None
+
+    def start(self, c):
+        c.heart = heartbeat.Heart(c.timer, c.event_dispatcher)
+        c.heart.start()
+
+    def stop(self, c):
+        c.heart = c.heart and c.heart.stop()
+    shutdown = stop
+
+
+class Control(bootsteps.StartStopStep):
+    name = 'consumer.control'
+    requires = (Events, )
+
+    def __init__(self, c, **kwargs):
+        self.is_green = c.pool is not None and c.pool.is_green
+        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
+        self.start = self.box.start
+        self.stop = self.box.stop
+        self.shutdown = self.box.shutdown
+
+
+class Tasks(bootsteps.StartStopStep):
+    name = 'consumer.tasks'
+    requires = (Control, )
+
+    def __init__(self, c, initial_prefetch_count=2, **kwargs):
+        c.task_consumer = c.qos = None
+        self.initial_prefetch_count = initial_prefetch_count
+
+    def start(self, c):
+        c.task_consumer = c.app.amqp.TaskConsumer(
+            c.connection, on_decode_error=c.on_decode_error,
+        )
+        c.qos = QoS(c.task_consumer, self.initial_prefetch_count)
+        c.qos.update()  # set initial prefetch count
+
+    def stop(self, c):
+        if c.task_consumer:
+            debug('Cancelling task consumer...')
+            ignore_errors(c, c.task_consumer.cancel)
+
+    def shutdown(self, c):
+        if c.task_consumer:
+            self.stop(c)
+            debug('Closing consumer channel...')
+            ignore_errors(c, c.task_consumer.close)
+            c.task_consumer = None
+
+
+class Evloop(bootsteps.StartStopStep):
+    name = 'consumer.evloop'
+    last = True
+
+    def start(self, c):
+        c.loop(*c.loop_args())

+ 0 - 212
celery/worker/consumer/components.py

@@ -1,212 +0,0 @@
-from __future__ import absolute_import
-
-import socket
-import threading
-
-from kombu.common import QoS
-
-from celery.datastructures import AttributeDict
-from celery.utils.log import get_logger
-
-from celery.worker.bootsteps import StartStopComponent
-from celery.worker.control import Panel
-from celery.worker.heartbeat import Heart
-
-logger = get_logger(__name__)
-info, error, debug = logger.info, logger.error, logger.debug
-
-
-class ConsumerConnection(StartStopComponent):
-    name = 'consumer.connection'
-    delay_shutdown = True
-
-    def __init__(self, c, **kwargs):
-        c.connection = None
-
-    def start(self, c):
-        debug('Re-establishing connection to the broker...')
-        c.connection = c._open_connection()
-        # Re-establish the broker connection and setup the task consumer.
-        info('consumer: Connected to %s.', c.connection.as_uri())
-
-    def stop(self, c):
-        pass
-
-    def shutdown(self, c):
-        # We must set self.connection to None here, so
-        # that the green pidbox thread exits.
-        connection, c.connection = c.connection, None
-
-        if connection:
-            c.maybe_conn_error(connection.close)
-
-
-class Events(StartStopComponent):
-    name = 'consumer.events'
-    requires = (ConsumerConnection, )
-
-    def __init__(self, c, send_events=None, **kwargs):
-        self.app = c.app
-        c.event_dispatcher = None
-        self.send_events = send_events
-
-    def start(self, c):
-        # Flush events sent while connection was down.
-        prev_event_dispatcher = c.event_dispatcher
-        c.event_dispatcher = self.app.events.Dispatcher(c.connection,
-                                                hostname=c.hostname,
-                                                enabled=self.send_events)
-        if prev_event_dispatcher:
-            c.event_dispatcher.copy_buffer(prev_event_dispatcher)
-            c.event_dispatcher.flush()
-
-    def stop(self, c):
-        if c.event_dispatcher:
-            debug('Shutting down event dispatcher...')
-            c.event_dispatcher = \
-                    c.maybe_conn_error(c.event_dispatcher.close)
-
-    def shutdown(self, c):
-        self.stop(c)
-
-
-class Heartbeat(StartStopComponent):
-    name = 'consumer.heart'
-    requires = (Events, )
-
-    def __init__(self, c, **kwargs):
-        c.heart = None
-
-    def start(self, c):
-        c.heart = Heart(c.timer, c.event_dispatcher)
-        c.heart.start()
-
-    def stop(self, c):
-        if c.heart:
-            # Stop the heartbeat thread if it's running.
-            debug('Heart: Going into cardiac arrest...')
-            c.heart = c.heart.stop()
-
-    def shutdown(self, c):
-        self.stop(c)
-
-
-class Controller(StartStopComponent):
-    name = 'consumer.controller'
-    requires = (Events, )
-
-    _pidbox_node_shutdown = None
-    _pidbox_node_stopped = None
-
-    def __init__(self, c, **kwargs):
-        self.app = c.app
-        pidbox_state = AttributeDict(
-            app=c.app, hostname=c.hostname, consumer=c,
-        )
-        self.pidbox_node = self.app.control.mailbox.Node(
-            c.hostname, state=pidbox_state, handlers=Panel.data,
-        )
-        self.broadcast_consumer = None
-        self.consumer = c
-
-    def start(self, c):
-        self.reset_pidbox_node()
-
-    def stop(self, c):
-        pass
-
-    def shutdown(self, c):
-        self.stop_pidbox_node()
-        if self.broadcast_consumer:
-            debug('Cancelling broadcast consumer...')
-            c.maybe_conn_error(self.broadcast_consumer.cancel)
-            self.broadcast_consumer = None
-
-    def on_control(self, body, message):
-        """Process remote control command message."""
-        try:
-            self.pidbox_node.handle_message(body, message)
-        except KeyError as exc:
-            error('No such control command: %s', exc)
-        except Exception as exc:
-            error('Control command error: %r', exc, exc_info=True)
-            self.reset_pidbox_node()
-
-    def reset_pidbox_node(self):
-        """Sets up the process mailbox."""
-        c = self.consumer
-        self.stop_pidbox_node()
-        # close previously opened channel if any.
-        if self.pidbox_node and self.pidbox_node.channel:
-            c.maybe_conn_error(self.pidbox_node.channel.close)
-
-        if c.pool is not None and c.pool.is_green:
-            return c.pool.spawn_n(self._green_pidbox_node)
-        self.pidbox_node.channel = c.connection.channel()
-        self.broadcast_consumer = self.pidbox_node.listen(
-                                        callback=self.on_control)
-
-    def stop_pidbox_node(self):
-        c = self.consumer
-        if self._pidbox_node_stopped:
-            self._pidbox_node_shutdown.set()
-            debug('Waiting for broadcast thread to shutdown...')
-            self._pidbox_node_stopped.wait()
-            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
-        elif self.broadcast_consumer:
-            debug('Closing broadcast channel...')
-            self.broadcast_consumer = \
-                c.maybe_conn_error(self.broadcast_consumer.channel.close)
-
-    def _green_pidbox_node(self):
-        """Sets up the process mailbox when running in a greenlet
-        environment."""
-        # THIS CODE IS TERRIBLE
-        # Luckily work has already started rewriting the Consumer for 4.0.
-        self._pidbox_node_shutdown = threading.Event()
-        self._pidbox_node_stopped = threading.Event()
-        c = self.consumer
-        try:
-            with c._open_connection() as conn:
-                info('pidbox: Connected to %s.', conn.as_uri())
-                self.pidbox_node.channel = conn.default_channel
-                self.broadcast_consumer = self.pidbox_node.listen(
-                                            callback=self.on_control)
-                with self.broadcast_consumer:
-                    while not self._pidbox_node_shutdown.isSet():
-                        try:
-                            conn.drain_events(timeout=1.0)
-                        except socket.timeout:
-                            pass
-        finally:
-            self._pidbox_node_stopped.set()
-
-
-class Tasks(StartStopComponent):
-    name = 'consumer.tasks'
-    requires = (Controller, )
-    last = True
-
-    def __init__(self, c, initial_prefetch_count=2, **kwargs):
-        self.app = c.app
-        c.task_consumer = None
-        c.qos = None
-        self.initial_prefetch_count = initial_prefetch_count
-
-    def start(self, c):
-        c.task_consumer = self.app.amqp.TaskConsumer(c.connection,
-                                    on_decode_error=c.on_decode_error)
-        # QoS: Reset prefetch window.
-        c.qos = QoS(c.task_consumer, self.initial_prefetch_count)
-        c.qos.update()
-
-    def stop(self, c):
-        pass
-
-    def shutdown(self, c):
-        if c.task_consumer:
-            debug('Cancelling task consumer...')
-            c.maybe_conn_error(c.task_consumer.cancel)
-            debug('Closing consumer channel...')
-            c.maybe_conn_error(c.task_consumer.close)
-            c.task_consumer = None

+ 2 - 2
celery/worker/hub.py

@@ -132,11 +132,11 @@ class Hub(object):
         self.on_task = []
         self.on_task = []
 
 
     def start(self):
     def start(self):
-        """Called by StartStopComponent at worker startup."""
+        """Called by Hub bootstep at worker startup."""
         self.poller = eventio.poll()
         self.poller = eventio.poll()
 
 
     def stop(self):
     def stop(self):
-        """Called by StartStopComponent at worker shutdown."""
+        """Called by Hub bootstep at worker shutdown."""
         self.poller.close()
         self.poller.close()
 
 
     def init(self):
     def init(self):

+ 11 - 16
celery/worker/consumer/loops.py → celery/worker/loops.py

@@ -1,8 +1,8 @@
 """
 """
-celery.worker.consumer.loop
-~~~~~~~~~~~~~~~~~~~~~~~~~~~
+celery.worker.loop
+~~~~~~~~~~~~~~~~~~
 
 
-Worker eventloop.
+The consumers highly-optimized inner loop.
 
 
 """
 """
 from __future__ import absolute_import
 from __future__ import absolute_import
@@ -15,18 +15,14 @@ from Queue import Empty
 from kombu.utils.eventio import READ, WRITE, ERR
 from kombu.utils.eventio import READ, WRITE, ERR
 
 
 from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.exceptions import InvalidTaskError, SystemTerminate
-from celery.utils.log import get_logger
 from celery.worker import state
 from celery.worker import state
 from celery.worker.bootsteps import CLOSE
 from celery.worker.bootsteps import CLOSE
 
 
-logger = get_logger(__name__)
-debug = logger.debug
-
 #: Heartbeat check is called every heartbeat_seconds' / rate'.
 #: Heartbeat check is called every heartbeat_seconds' / rate'.
 AMQHEARTBEAT_RATE = 2.0
 AMQHEARTBEAT_RATE = 2.0
 
 
 
 
-def asynloop(obj, connection, consumer, strategies, namespace, hub, qos,
+def asynloop(obj, connection, consumer, strategies, ns, hub, qos,
         heartbeat, handle_unknown_message, handle_unknown_task,
         heartbeat, handle_unknown_message, handle_unknown_task,
         handle_invalid_task, sleep=sleep, min=min, Empty=Empty,
         handle_invalid_task, sleep=sleep, min=min, Empty=Empty,
         hbrate=AMQHEARTBEAT_RATE):
         hbrate=AMQHEARTBEAT_RATE):
@@ -67,10 +63,9 @@ def asynloop(obj, connection, consumer, strategies, namespace, hub, qos,
 
 
         consumer.callbacks = [on_task_received]
         consumer.callbacks = [on_task_received]
         consumer.consume()
         consumer.consume()
+        obj.on_ready()
 
 
-        debug('Ready to accept tasks!')
-
-        while namespace.state != CLOSE and obj.connection:
+        while ns.state != CLOSE and obj.connection:
             # shutdown if signal handlers told us to.
             # shutdown if signal handlers told us to.
             if state.should_stop:
             if state.should_stop:
                 raise SystemExit()
                 raise SystemExit()
@@ -112,7 +107,7 @@ def asynloop(obj, connection, consumer, strategies, namespace, hub, qos,
                         except (KeyError, Empty):
                         except (KeyError, Empty):
                             continue
                             continue
                         except socket.error:
                         except socket.error:
-                            if namespace.state != CLOSE:  # pragma: no cover
+                            if ns.state != CLOSE:  # pragma: no cover
                                 raise
                                 raise
                     if keep_draining:
                     if keep_draining:
                         drain_nowait()
                         drain_nowait()
@@ -124,7 +119,7 @@ def asynloop(obj, connection, consumer, strategies, namespace, hub, qos,
                 sleep(min(poll_timeout, 0.1))
                 sleep(min(poll_timeout, 0.1))
 
 
 
 
-def synloop(obj, connection, consumer, strategies, namespace, hub, qos,
+def synloop(obj, connection, consumer, strategies, ns, hub, qos,
         heartbeat, handle_unknown_message, handle_unknown_task,
         heartbeat, handle_unknown_message, handle_unknown_task,
         handle_invalid_task, **kwargs):
         handle_invalid_task, **kwargs):
     """Fallback blocking eventloop for transports that doesn't support AIO."""
     """Fallback blocking eventloop for transports that doesn't support AIO."""
@@ -145,9 +140,9 @@ def synloop(obj, connection, consumer, strategies, namespace, hub, qos,
     consumer.register_callback(on_task_received)
     consumer.register_callback(on_task_received)
     consumer.consume()
     consumer.consume()
 
 
-    debug('Ready to accept tasks!')
+    obj.on_ready()
 
 
-    while namespace.state != CLOSE and obj.connection:
+    while ns.state != CLOSE and obj.connection:
         state.maybe_shutdown()
         state.maybe_shutdown()
         if qos.prev != qos.value:         # pragma: no cover
         if qos.prev != qos.value:         # pragma: no cover
             qos.update()
             qos.update()
@@ -156,5 +151,5 @@ def synloop(obj, connection, consumer, strategies, namespace, hub, qos,
         except socket.timeout:
         except socket.timeout:
             pass
             pass
         except socket.error:
         except socket.error:
-            if namespace.state != CLOSE:  # pragma: no cover
+            if ns.state != CLOSE:  # pragma: no cover
                 raise
                 raise

+ 2 - 2
celery/worker/mediator.py

@@ -23,12 +23,12 @@ from celery.app import app_or_default
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 
 
-from .bootsteps import StartStopComponent
+from .bootsteps import StartStopStep
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-class WorkerComponent(StartStopComponent):
+class WorkerComponent(StartStopStep):
     name = 'worker.mediator'
     name = 'worker.mediator'
     requires = ('pool', 'queues', )
     requires = ('pool', 'queues', )
 
 

+ 103 - 0
celery/worker/pidbox.py

@@ -0,0 +1,103 @@
+from __future__ import absolute_import
+
+import socket
+import threading
+
+from kombu.common import ignore_errors
+
+from celery.datastructures import AttributeDict
+from celery.utils.log import get_logger
+
+from . import control
+
+logger = get_logger(__name__)
+debug, error, info = logger.debug, logger.error, logger.info
+
+
+class Pidbox(object):
+    consumer = None
+
+    def __init__(self, c):
+        self.c = c
+        self.hostname = c.hostname
+        self.node = c.app.control.mailbox.Node(c.hostname,
+            handlers=control.Panel.data,
+            state=AttributeDict(app=c.app, hostname=c.hostname, consumer=c),
+        )
+
+    def on_message(self, body, message):
+        try:
+            self.node.handle_message(body, message)
+        except KeyError as exc:
+            error('No such control command: %s', exc)
+        except Exception as exc:
+            error('Control command error: %r', exc, exc_info=True)
+            self.reset()
+
+    def start(self, c):
+        self.node.channel = c.connection.channel()
+        self.consumer = self.node.listen(callback=self.on_message)
+
+    def stop(self, c):
+        self.consumer = self._close_channel(c)
+
+    def reset(self):
+        """Sets up the process mailbox."""
+        self.stop(self.c)
+        self.start(self.c)
+
+    def _close_channel(self, c):
+        if self.node and self.node.channel:
+            ignore_errors(c, self.node.channel.close)
+
+    def shutdown(self, c):
+        if self.consumer:
+            debug('Cancelling broadcast consumer...')
+            ignore_errors(c, self.consumer.cancel)
+        self.stop(self.c)
+
+
+class gPidbox(Pidbox):
+    _node_shutdown = None
+    _node_stopped = None
+    _resets = 0
+
+    def start(self, c):
+        c.pool.spawn_n(self.loop, c)
+
+    def stop(self, c):
+        if self._node_stopped:
+            self._node_shutdown.set()
+            debug('Waiting for broadcast thread to shutdown...')
+            self._node_stopped.wait()
+            self._node_stopped = self._node_shutdown = None
+        super(gPidbox, self).stop(c)
+
+    def reset(self):
+        self._resets += 1
+
+    def _do_reset(self, c, connection):
+        self._close_channel(c)
+        self.node.channel = connection.channel()
+        self.consumer = self.node.listen(callback=self.on_message)
+        self.consumer.consume()
+
+    def loop(self, c):
+        resets = [self._resets]
+        shutdown = self._node_shutdown = threading.Event()
+        stopped = self._node_stopped = threading.Event()
+        try:
+            with c.connect() as connection:
+
+                info('pidbox: Connected to %s.', connection.as_uri())
+                self._do_reset(c, connection)
+                while not shutdown.is_set() and c.connection:
+                    if resets[0] < self._resets:
+                        resets[0] += 1
+                        self._do_reset(c, connection)
+                    try:
+                        connection.drain_events(timeout=1.0)
+                    except socket.timeout:
+                        pass
+        finally:
+            stopped.set()