Преглед на файлове

[consumer bootsteps] Tests passing

Ask Solem преди 12 години
родител
ревизия
ec2e3c36ba

+ 5 - 5
celery/app/defaults.py

@@ -150,8 +150,8 @@ NAMESPACES = {
         'WORKER_DIRECT': Option(False, type='bool'),
     },
     'CELERYD': {
-        'AUTOSCALER': Option('celery.worker.autoscale.Autoscaler'),
-        'AUTORELOADER': Option('celery.worker.autoreload.Autoreloader'),
+        'AUTOSCALER': Option('celery.worker.autoscale:Autoscaler'),
+        'AUTORELOADER': Option('celery.worker.autoreload:Autoreloader'),
         'BOOT_STEPS': Option((), type='tuple'),
         'CONSUMER_BOOT_STEPS': Option((), type='tuple'),
         'CONCURRENCY': Option(0, type='int'),
@@ -159,14 +159,14 @@ NAMESPACES = {
         'TIMER_PRECISION': Option(1.0, type='float'),
         'FORCE_EXECV': Option(True, type='bool'),
         'HIJACK_ROOT_LOGGER': Option(True, type='bool'),
-        'CONSUMER': Option(type='string'),
+        'CONSUMER': Option('celery.worker.consumer:Consumer', type='string'),
         'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT),
         'LOG_COLOR': Option(type='bool'),
         'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0',
                             alt='--loglevel argument'),
         'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
                             alt='--logfile argument'),
-        'MEDIATOR': Option('celery.worker.mediator.Mediator'),
+        'MEDIATOR': Option('celery.worker.mediator:Mediator'),
         'MAX_TASKS_PER_CHILD': Option(type='int'),
         'POOL': Option(DEFAULT_POOL),
         'POOL_PUTLOCKS': Option(True, type='bool'),
@@ -180,7 +180,7 @@ NAMESPACES = {
     },
     'CELERYBEAT': {
         'SCHEDULE': Option({}, type='dict'),
-        'SCHEDULER': Option('celery.beat.PersistentScheduler'),
+        'SCHEDULER': Option('celery.beat:PersistentScheduler'),
         'SCHEDULE_FILENAME': Option('celerybeat-schedule'),
         'MAX_LOOP_INTERVAL': Option(0, type='float'),
         'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0',

+ 8 - 3
celery/apps/worker.py

@@ -101,6 +101,10 @@ class Worker(WorkController):
             enabled=not no_color if no_color is not None else no_color
         )
 
+    def on_init_namespace(self):
+        print('SETUP LOGGING: %r' % (self.redirect_stdouts, ))
+        self.setup_logging()
+
     def on_start(self):
         WorkController.on_start(self)
 
@@ -122,10 +126,11 @@ class Worker(WorkController):
 
         # Dump configuration to screen so we have some basic information
         # for when users send bug reports.
-        print(str(self.colored.cyan(' \n', self.startup_info())) +
-              str(self.colored.reset(self.extra_info() or '')))
+        sys.__stdout__.write(
+            str(self.colored.cyan(' \n', self.startup_info())) +
+            str(self.colored.reset(self.extra_info() or '')) + '\n'
+        )
         self.set_process_status('-active-')
-        self.setup_logging()
         self.install_platform_tweaks(self)
 
     def on_consumer_ready(self, consumer):

+ 3 - 8
celery/tests/bin/test_celeryd.py

@@ -60,10 +60,7 @@ def disable_stdouts(fun):
 
 
 class Worker(cd.Worker):
-
-    def __init__(self, *args, **kwargs):
-        super(Worker, self).__init__(*args, **kwargs)
-        self.redirect_stdouts = False
+    redirect_stdouts = False
 
     def start(self, *args, **kwargs):
         self.on_start()
@@ -292,9 +289,7 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_redirect_stdouts(self):
-        worker = self.Worker()
-        worker.redirect_stdouts = False
-        worker.setup_logging()
+        worker = self.Worker(redirect_stdouts=False)
         with self.assertRaises(AttributeError):
             sys.stdout.logger
 
@@ -306,7 +301,7 @@ class test_Worker(WorkerAppCase):
             logging_setup[0] = True
 
         try:
-            worker = self.Worker()
+            worker = self.Worker(redirect_stdouts=False)
             worker.app.log.__class__._setup = False
             worker.setup_logging()
             self.assertTrue(logging_setup[0])

+ 123 - 84
celery/tests/worker/test_worker.py

@@ -28,7 +28,8 @@ from celery.worker.components import Queues, Timers, EvLoop, Pool
 from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
-from celery.worker.consumer import BlockingConsumer
+from celery.worker.consumer import Consumer
+from celery.worker.consumer import components as consumer_components
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
@@ -39,7 +40,14 @@ class PlaceHolder(object):
         pass
 
 
-class MyKombuConsumer(BlockingConsumer):
+def find_component(obj, typ):
+    for c in obj.namespace.boot_steps:
+        if isinstance(c, typ):
+            return c
+    raise Exception('Instance %s has no %s component' % (obj, typ))
+
+
+class MyKombuConsumer(Consumer):
     broadcast_consumer = Mock()
     task_consumer = Mock()
 
@@ -227,45 +235,52 @@ class test_Consumer(Case):
 
     def test_start_when_closed(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.start()
 
     def test_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
 
-        l._state = RUN
+        l.namespace.state = RUN
         l.event_dispatcher = None
-        l.stop_consumers(close_connection=False)
+        l.restart()
         self.assertTrue(l.connection)
 
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
-        l.stop_consumers()
+        l.restart()
 
         l.stop()
-        l.close_connection()
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
 
     def test_close_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
-        l.close_connection()
+        l.namespace.state = RUN
+        comp = find_component(l, consumer_components.ConsumerConnection)
+        conn = l.connection = Mock()
+        comp.shutdown(l)
+        self.assertTrue(conn.close.called)
+        self.assertIsNone(l.connection)
 
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         eventer = l.event_dispatcher = Mock()
         eventer.enabled = True
         heart = l.heart = MockHeart()
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        Events = find_component(l, consumer_components.Events)
+        Events.shutdown(l)
+        Heart = find_component(l, consumer_components.Heartbeat)
+        Heart.shutdown(l)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
 
@@ -277,10 +292,11 @@ class test_Consumer(Case):
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(warn.call_count)
 
-    @patch('celery.utils.timer2.to_timestamp')
+    @patch('celery.worker.consumer.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
@@ -292,7 +308,8 @@ class test_Consumer(Case):
         l.pidbox_node = MockNode()
         l.update_strategies()
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(m.acknowledged)
         self.assertTrue(to_timestamp.call_count)
 
@@ -303,9 +320,9 @@ class test_Consumer(Case):
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
     @patch('celery.worker.consumer.crit')
@@ -322,14 +339,25 @@ class test_Consumer(Case):
         self.assertTrue(message.ack.call_count)
         self.assertIn("Can't decode message body", crit.call_args[0][0])
 
+    def _get_on_message(self, l):
+        l.qos = Mock()
+        l.event_dispatcher = Mock()
+        l.task_consumer = Mock()
+        l.connection = Mock()
+        l.connection.drain_events.side_effect = SystemExit()
+
+        with self.assertRaises(SystemExit):
+            l.loop(*l.loop_args())
+        self.assertTrue(l.task_consumer.register_callback.called)
+        return l.task_consumer.register_callback.call_args[0][0]
+
     def test_receieve_message(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
-
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
 
         in_bucket = self.ready_queue.get_nowait()
         self.assertIsInstance(in_bucket, Request)
@@ -339,10 +367,10 @@ class test_Consumer(Case):
 
     def test_start_connection_error(self):
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                     self.iterations = 1
                     raise KeyError('foo')
@@ -360,10 +388,10 @@ class test_Consumer(Case):
         # Regression test for AMQPChannelExceptions that can occur within the
         # consumer. (i.e. 404 errors)
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                     self.iterations = 1
                     raise KeyError('foo')
@@ -377,7 +405,7 @@ class test_Consumer(Case):
         l.heart.stop()
         l.timer.stop()
 
-    def test_consume_messages_ignores_socket_timeout(self):
+    def test_loop_ignores_socket_timeout(self):
 
         class Connection(current_app.connection().__class__):
             obj = None
@@ -391,9 +419,9 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.connection.obj = l
         l.qos = QoS(l.task_consumer, 10)
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
-    def test_consume_messages_when_socket_error(self):
+    def test_loop_when_socket_error(self):
 
         class Connection(current_app.connection().__class__):
             obj = None
@@ -403,19 +431,19 @@ class test_Consumer(Case):
                 raise socket.error('foo')
 
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
+        l.namespace.state = RUN
         c = l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
         with self.assertRaises(socket.error):
-            l.consume_messages()
+            l.loop(*l.loop_args())
 
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.connection = c
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
-    def test_consume_messages(self):
+    def test_loop(self):
 
         class Connection(current_app.connection().__class__):
             obj = None
@@ -429,8 +457,8 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
 
-        l.consume_messages()
-        l.consume_messages()
+        l.loop(*l.loop_args())
+        l.loop(*l.loop_args())
         self.assertTrue(l.task_consumer.consume.call_count)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
         l.task_consumer.qos = Mock()
@@ -470,12 +498,13 @@ class test_Consumer(Case):
                            args=[2, 4, 8], kwargs={})
 
         l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
+        l.qos = QoS(l.task_consumer, 1)
         current_pcount = l.qos.value
         l.event_dispatcher = Mock()
         l.enabled = False
         l.update_strategies()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         l.timer.join(1)
 
@@ -490,22 +519,23 @@ class test_Consumer(Case):
 
     def test_on_control(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        l.reset_pidbox_node = Mock()
+        con = find_component(l, consumer_components.Controller)
+        con.pidbox_node = Mock()
+        con.reset_pidbox_node = Mock()
 
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.on_control('foo', 'bar')
+        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = KeyError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.pidbox_node = Mock()
+        con.pidbox_node.handle_message.side_effect = KeyError('foo')
+        con.on_control('foo', 'bar')
+        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = ValueError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
-        l.reset_pidbox_node.assert_called_with()
+        con.pidbox_node = Mock()
+        con.pidbox_node.handle_message.side_effect = ValueError('foo')
+        con.on_control('foo', 'bar')
+        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.reset_pidbox_node.assert_called_with()
 
     def test_revoke(self):
         ready_queue = FastQueue()
@@ -517,7 +547,8 @@ class test_Consumer(Case):
         from celery.worker.state import revoked
         revoked.add(id)
 
-        l.receive_message(t.decode(), t)
+        callback = self._get_on_message(l)
+        callback(t.decode(), t)
         self.assertTrue(ready_queue.empty())
 
     def test_receieve_message_not_registered(self):
@@ -526,7 +557,8 @@ class test_Consumer(Case):
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
 
         l.event_dispatcher = Mock()
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
         self.assertTrue(self.timer.empty())
@@ -542,7 +574,8 @@ class test_Consumer(Case):
         l.connection_errors = (socket.error, )
         m.reject = Mock()
         m.reject.side_effect = socket.error('foo')
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         self.assertTrue(warn.call_count)
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
@@ -560,16 +593,17 @@ class test_Consumer(Case):
                            eta=(datetime.now() +
                                timedelta(days=1)).isoformat())
 
-        l.reset_connection()
+        l.namespace.start(l)
         p = l.app.conf.BROKER_CONNECTION_RETRY
         l.app.conf.BROKER_CONNECTION_RETRY = False
         try:
-            l.reset_connection()
+            l.namespace.start(l)
         finally:
             l.app.conf.BROKER_CONNECTION_RETRY = p
-        l.stop_consumers()
+        l.restart()
         l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         in_hold = l.timer.queue[0]
         self.assertEqual(len(in_hold), 3)
@@ -583,24 +617,27 @@ class test_Consumer(Case):
 
     def test_reset_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        chan = l.pidbox_node.channel = Mock()
+        con = find_component(l, consumer_components.Controller)
+        con.pidbox_node = Mock()
+        chan = con.pidbox_node.channel = Mock()
         l.connection = Mock()
         chan.close.side_effect = socket.error('foo')
         l.connection_errors = (socket.error, )
-        l.reset_pidbox_node()
+        con.reset_pidbox_node()
         chan.close.assert_called_with()
 
     def test_reset_pidbox_node_green(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        con = find_component(l, consumer_components.Controller)
         l.pool = Mock()
         l.pool.is_green = True
-        l.reset_pidbox_node()
-        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+        con.reset_pidbox_node()
+        l.pool.spawn_n.assert_called_with(con._green_pidbox_node)
 
     def test__green_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l.pidbox_node = Mock()
+        cont = find_component(l, consumer_components.Controller)
 
         class BConsumer(Mock):
 
@@ -611,7 +648,7 @@ class test_Consumer(Case):
             def __exit__(self, *exc_info):
                 self.cancel()
 
-        l.pidbox_node.listen = BConsumer()
+        cont.pidbox_node.listen = BConsumer()
         connections = []
 
         class Connection(object):
@@ -640,18 +677,19 @@ class test_Consumer(Case):
                     self.calls += 1
                     raise socket.timeout()
                 self.obj.connection = None
-                self.obj._pidbox_node_shutdown.set()
+                cont._pidbox_node_shutdown.set()
 
             def close(self):
                 self.closed = True
 
         l.connection = Mock()
         l._open_connection = lambda: Connection(obj=l)
-        l._green_pidbox_node()
+        controller = find_component(l, consumer_components.Controller)
+        controller._green_pidbox_node()
 
-        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
-        self.assertTrue(l.broadcast_consumer)
-        l.broadcast_consumer.consume.assert_called_with()
+        cont.pidbox_node.listen.assert_called_with(callback=cont.on_control)
+        self.assertTrue(cont.broadcast_consumer)
+        cont.broadcast_consumer.consume.assert_called_with()
 
         self.assertIsNone(l.connection)
         self.assertTrue(connections[0].closed)
@@ -673,12 +711,13 @@ class test_Consumer(Case):
 
     def test_stop_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._pidbox_node_stopped = Event()
-        l._pidbox_node_shutdown = Event()
-        l._pidbox_node_stopped.set()
-        l.stop_pidbox_node()
+        cont = find_component(l, consumer_components.Controller)
+        cont._pidbox_node_stopped = Event()
+        cont._pidbox_node_shutdown = Event()
+        cont._pidbox_node_stopped.set()
+        cont.stop_pidbox_node()
 
-    def test_start__consume_messages(self):
+    def test_start__loop(self):
 
         class _QoS(object):
             prev = 3
@@ -703,18 +742,18 @@ class test_Consumer(Case):
         l.connection = Connection()
         l.iterations = 0
 
-        def raises_KeyError(limit=None):
+        def raises_KeyError(*args, **kwargs):
             l.iterations += 1
             if l.qos.prev != l.qos.value:
                 l.qos.update()
             if l.iterations >= 2:
                 raise KeyError('foo')
 
-        l.consume_messages = raises_KeyError
+        l.loop = raises_KeyError
         with self.assertRaises(KeyError):
             l.start()
         self.assertTrue(init_callback.call_count)
-        self.assertEqual(l.iterations, 1)
+        self.assertEqual(l.iterations, 2)
         self.assertEqual(l.qos.prev, l.qos.value)
 
         init_callback.reset_mock()
@@ -724,25 +763,25 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.connection = Connection()
-        l.consume_messages = Mock(side_effect=socket.error('foo'))
+        l.loop = Mock(side_effect=socket.error('foo'))
         with self.assertRaises(socket.error):
             l.start()
         self.assertTrue(init_callback.call_count)
-        self.assertTrue(l.consume_messages.call_count)
+        self.assertTrue(l.loop.call_count)
 
     def test_reset_connection_with_no_node(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         self.assertEqual(None, l.pool)
-        l.reset_connection()
+        l.namespace.start(l)
 
     def test_on_task_revoked(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task.revoked.return_value = True
         l.on_task(task)
 
     def test_on_task_no_events(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task.revoked.return_value = False
         l.event_dispatcher = Mock()

+ 5 - 0
celery/worker/__init__.py

@@ -123,12 +123,17 @@ class WorkController(configurated):
         # Initialize boot steps
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.components = []
+        self.on_init_namespace()
         self.namespace = Namespace(app=self.app,
                                    on_start=self.on_start,
                                    on_close=self.on_close,
                                    on_stopped=self.on_stopped)
         self.namespace.apply(self, **kwargs)
 
+
+    def on_init_namespace(self):
+        pass
+
     def on_before_init(self, **kwargs):
         pass
 

+ 21 - 169
celery/worker/consumer/__init__.py

@@ -13,15 +13,10 @@ from __future__ import absolute_import
 import logging
 import socket
 
-from time import sleep
-from Queue import Empty
-
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr
-from kombu.utils.eventio import READ, WRITE, ERR
 
 from celery.app import app_or_default
-from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.task.trace import build_tracer
 from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.functional import noop
@@ -30,16 +25,16 @@ from celery.utils.log import get_logger
 from celery.utils.text import dump_body
 from celery.utils.timeutils import humanize_seconds
 from celery.worker import state
+from celery.worker.state import maybe_shutdown
 from celery.worker.bootsteps import Namespace as _NS, StartStopComponent, CLOSE
 
+from . import loops
+
 logger = get_logger(__name__)
 info, warn, error, crit = (logger.info, logger.warn,
                            logger.error, logger.critical)
 task_reserved = state.task_reserved
 
-#: Heartbeat check is called every heartbeat_seconds' / rate'.
-AMQHEARTBEAT_RATE = 2.0
-
 CONNECTION_RETRY = """\
 consumer: Connection to broker lost. \
 Trying to re-establish the connection...\
@@ -102,13 +97,9 @@ class Component(StartStopComponent):
     name = 'worker.consumer'
     last = True
 
-    def Consumer(self, w):
-        return (w.consumer_cls or
-                Consumer if w.hub else BlockingConsumer)
-
     def create(self, w):
         prefetch_count = w.concurrency * w.prefetch_multiplier
-        c = w.consumer = self.instantiate(self.Consumer(w),
+        c = w.consumer = self.instantiate(w.consumer_cls,
                 w.ready_queue,
                 hostname=w.hostname,
                 send_events=w.send_events,
@@ -191,6 +182,8 @@ class Consumer(object):
         self.channel_errors = conninfo.channel_errors
 
         self._does_info = logger.isEnabledFor(logging.INFO)
+        if not hasattr(self, 'loop'):
+            self.loop = loops.asynloop if hub else loops.synloop
         if hub:
             hub.on_init.append(self.on_poll_init)
         self.hub = hub
@@ -226,17 +219,24 @@ class Consumer(object):
         consuming messages.
 
         """
-        ns = self.namespace
+        ns, loop, loop_args = self.namespace, self.loop, self.loop_args()
         while ns.state != CLOSE:
-            self.maybe_shutdown()
+            maybe_shutdown()
             try:
-                self.namespace.start(self)
-                self.consume_messages()
+                ns.start(self)
+                loop(*loop_args)
             except self.connection_errors + self.channel_errors:
                 error(CONNECTION_RETRY, exc_info=True)
-                ns.restart(self)
-            ns.close(self)
-            ns.state = CLOSE
+                self.restart()
+
+    def loop_args(self):
+        return (self, self.connection, self.task_consumer,
+                self.strategies, self.namespace, self.hub, self.qos,
+                self.amqheartbeat, self.handle_unknown_message,
+                self.handle_unknown_task, self.handle_invalid_task)
+
+    def restart(self):
+        return self.namespace.restart(self)
 
     def on_poll_init(self, hub):
         hub.update_readers(self.connection.eventmap)
@@ -304,7 +304,7 @@ class Consumer(object):
 
         return conn.ensure_connection(_error_handler,
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
-                    callback=self.maybe_shutdown)
+                    callback=maybe_shutdown)
 
     def stop(self):
         """Stop consuming.
@@ -315,12 +315,6 @@ class Consumer(object):
         """
         self.namespace.stop(self)
 
-    def maybe_shutdown(self):
-        if state.should_stop:
-            raise SystemExit()
-        elif state.should_terminate:
-            raise SystemTerminate()
-
     def add_task_queue(self, queue, exchange=None, exchange_type=None,
             routing_key=None, **options):
         cset = self.task_consumer
@@ -358,106 +352,6 @@ class Consumer(object):
             conninfo.pop('password', None)  # don't send password.
         return {'broker': conninfo, 'prefetch_count': self.qos.value}
 
-    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
-            hbrate=AMQHEARTBEAT_RATE):
-        """Consume messages forever (or until an exception is raised)."""
-
-        with self.hub as hub:
-            ns = self.namespace
-            qos = self.qos
-            update_qos = qos.update
-            update_readers = hub.update_readers
-            readers, writers = hub.readers, hub.writers
-            poll = hub.poller.poll
-            fire_timers = hub.fire_timers
-            scheduled = hub.timer._queue
-            connection = self.connection
-            hb = self.amqheartbeat
-            hbtick = connection.heartbeat_check
-            on_poll_start = connection.transport.on_poll_start
-            on_poll_empty = connection.transport.on_poll_empty
-            strategies = self.strategies
-            drain_nowait = connection.drain_nowait
-            on_task_callbacks = hub.on_task
-            keep_draining = connection.transport.nb_keep_draining
-
-            if hb and connection.supports_heartbeats:
-                hub.timer.apply_interval(
-                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))
-
-            def on_task_received(body, message):
-                if on_task_callbacks:
-                    [callback() for callback in on_task_callbacks]
-                try:
-                    name = body['task']
-                except (KeyError, TypeError):
-                    return self.handle_unknown_message(body, message)
-                try:
-                    strategies[name](message, body, message.ack_log_error)
-                except KeyError as exc:
-                    self.handle_unknown_task(body, message, exc)
-                except InvalidTaskError as exc:
-                    self.handle_invalid_task(body, message, exc)
-                #fire_timers()
-
-            self.task_consumer.callbacks = [on_task_received]
-            self.task_consumer.consume()
-
-            debug('Ready to accept tasks!')
-
-            while ns.state != CLOSE and self.connection:
-                # shutdown if signal handlers told us to.
-                if state.should_stop:
-                    raise SystemExit()
-                elif state.should_terminate:
-                    raise SystemTerminate()
-
-                # fire any ready timers, this also returns
-                # the number of seconds until we need to fire timers again.
-                poll_timeout = fire_timers() if scheduled else 1
-
-                # We only update QoS when there is no more messages to read.
-                # This groups together qos calls, and makes sure that remote
-                # control commands will be prioritized over task messages.
-                if qos.prev != qos.value:
-                    update_qos()
-
-                update_readers(on_poll_start())
-                if readers or writers:
-                    connection.more_to_read = True
-                    while connection.more_to_read:
-                        try:
-                            events = poll(poll_timeout)
-                        except ValueError:  # Issue 882
-                            return
-                        if not events:
-                            on_poll_empty()
-                        for fileno, event in events or ():
-                            try:
-                                if event & READ:
-                                    readers[fileno](fileno, event)
-                                if event & WRITE:
-                                    writers[fileno](fileno, event)
-                                if event & ERR:
-                                    for handlermap in readers, writers:
-                                        try:
-                                            handlermap[fileno](fileno, event)
-                                        except KeyError:
-                                            pass
-                            except (KeyError, Empty):
-                                continue
-                            except socket.error:
-                                if ns.state != CLOSE:  # pragma: no cover
-                                    raise
-                        if keep_draining:
-                            drain_nowait()
-                            poll_timeout = 0
-                        else:
-                            connection.more_to_read = False
-                else:
-                    # no sockets yet, startup is probably not done.
-                    sleep(min(poll_timeout, 0.1))
-
     def on_task(self, task, task_reserved=task_reserved):
         """Handle received task.
 
@@ -527,45 +421,3 @@ class Consumer(object):
         for name, task in self.app.tasks.iteritems():
             S[name] = task.start_strategy(app, self)
             task.__trace__ = build_tracer(name, task, loader, hostname)
-
-
-class BlockingConsumer(Consumer):
-
-    def consume_messages(self):
-        # receive_message handles incoming messages.
-        self.task_consumer.register_callback(self.receive_message)
-        self.task_consumer.consume()
-
-        debug('Ready to accept tasks!')
-        ns = self.namespace
-
-        while ns.state != CLOSE and self.connection:
-            self.maybe_shutdown()
-            if self.qos.prev != self.qos.value:     # pragma: no cover
-                self.qos.update()
-            try:
-                self.connection.drain_events(timeout=10.0)
-            except socket.timeout:
-                pass
-            except socket.error:
-                if ns.state != CLOSE:            # pragma: no cover
-                    raise
-
-    def receive_message(self, body, message):
-        """Handles incoming messages.
-
-        :param body: The message body.
-        :param message: The kombu message object.
-
-        """
-        try:
-            name = body['task']
-        except (KeyError, TypeError):
-            return self.handle_unknown_message(body, message)
-
-        try:
-            self.strategies[name](message, body, message.ack_log_error)
-        except KeyError as exc:
-            self.handle_unknown_task(body, message, exc)
-        except InvalidTaskError as exc:
-            self.handle_invalid_task(body, message, exc)

+ 4 - 5
celery/worker/consumer/components.py

@@ -67,7 +67,7 @@ class Events(StartStopComponent):
                     c.maybe_conn_error(c.event_dispatcher.close)
 
     def shutdown(self, c):
-        pass
+        self.stop(c)
 
 
 class Heartbeat(StartStopComponent):
@@ -88,7 +88,7 @@ class Heartbeat(StartStopComponent):
             c.heart = c.heart.stop()
 
     def shutdown(self, c):
-        pass
+        self.stop(c)
 
 
 class Controller(StartStopComponent):
@@ -100,7 +100,6 @@ class Controller(StartStopComponent):
 
     def __init__(self, c, **kwargs):
         self.app = c.app
-        self.pool = c.pool
         pidbox_state = AttributeDict(
             app=c.app, hostname=c.hostname, consumer=c,
         )
@@ -141,8 +140,8 @@ class Controller(StartStopComponent):
         if self.pidbox_node and self.pidbox_node.channel:
             c.maybe_conn_error(self.pidbox_node.channel.close)
 
-        if self.pool is not None and self.pool.is_green:
-            return self.pool.spawn_n(self._green_pidbox_node)
+        if c.pool is not None and c.pool.is_green:
+            return c.pool.spawn_n(self._green_pidbox_node)
         self.pidbox_node.channel = c.connection.channel()
         self.broadcast_consumer = self.pidbox_node.listen(
                                         callback=self.on_control)

+ 160 - 0
celery/worker/consumer/loops.py

@@ -0,0 +1,160 @@
+"""
+celery.worker.consumer.loops
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Worker eventloop.
+
+"""
+from __future__ import absolute_import
+
+import socket
+
+from time import sleep
+from Queue import Empty
+
+from kombu.utils.eventio import READ, WRITE, ERR
+
+from celery.exceptions import InvalidTaskError, SystemTerminate
+from celery.utils.log import get_logger
+from celery.worker import state
+from celery.worker.bootsteps import CLOSE
+
+logger = get_logger(__name__)
+debug = logger.debug
+
+#: Heartbeat check is called every heartbeat_seconds / rate seconds.
+AMQHEARTBEAT_RATE = 2.0
+
+
+def asynloop(obj, connection, consumer, strategies, namespace, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, sleep=sleep, min=min, Empty=Empty,
+        hbrate=AMQHEARTBEAT_RATE):
+    """Non-blocking eventloop consuming messages until connection is lost,
+    or shutdown is requested."""
+
+    with hub as hub:
+        update_qos = qos.update
+        update_readers = hub.update_readers
+        readers, writers = hub.readers, hub.writers
+        poll = hub.poller.poll
+        fire_timers = hub.fire_timers
+        scheduled = hub.timer._queue
+        hbtick = connection.heartbeat_check
+        on_poll_start = connection.transport.on_poll_start
+        on_poll_empty = connection.transport.on_poll_empty
+        drain_nowait = connection.drain_nowait
+        on_task_callbacks = hub.on_task
+        keep_draining = connection.transport.nb_keep_draining
+
+        if heartbeat and connection.supports_heartbeats:
+            hub.timer.apply_interval(
+                heartbeat * 1000.0 / hbrate, hbtick, (hbrate, ))
+
+        def on_task_received(body, message):
+            if on_task_callbacks:
+                [callback() for callback in on_task_callbacks]
+            try:
+                name = body['task']
+            except (KeyError, TypeError):
+                return handle_unknown_message(body, message)
+            try:
+                strategies[name](message, body, message.ack_log_error)
+            except KeyError as exc:
+                handle_unknown_task(body, message, exc)
+            except InvalidTaskError as exc:
+                handle_invalid_task(body, message, exc)
+
+        consumer.callbacks = [on_task_received]
+        consumer.consume()
+
+        debug('Ready to accept tasks!')
+
+        while namespace.state != CLOSE and obj.connection:
+            # shutdown if signal handlers told us to.
+            if state.should_stop:
+                raise SystemExit()
+            elif state.should_terminate:
+                raise SystemTerminate()
+
+            # fire any ready timers, this also returns
+            # the number of seconds until we need to fire timers again.
+            poll_timeout = fire_timers() if scheduled else 1
+
+            # We only update QoS when there is no more messages to read.
+            # This groups together qos calls, and makes sure that remote
+            # control commands will be prioritized over task messages.
+            if qos.prev != qos.value:
+                update_qos()
+
+            update_readers(on_poll_start())
+            if readers or writers:
+                connection.more_to_read = True
+                while connection.more_to_read:
+                    try:
+                        events = poll(poll_timeout)
+                    except ValueError:  # Issue 882
+                        return
+                    if not events:
+                        on_poll_empty()
+                    for fileno, event in events or ():
+                        try:
+                            if event & READ:
+                                readers[fileno](fileno, event)
+                            if event & WRITE:
+                                writers[fileno](fileno, event)
+                            if event & ERR:
+                                for handlermap in readers, writers:
+                                    try:
+                                        handlermap[fileno](fileno, event)
+                                    except KeyError:
+                                        pass
+                        except (KeyError, Empty):
+                            continue
+                        except socket.error:
+                            if namespace.state != CLOSE:  # pragma: no cover
+                                raise
+                    if keep_draining:
+                        drain_nowait()
+                        poll_timeout = 0
+                    else:
+                        connection.more_to_read = False
+            else:
+                # no sockets yet, startup is probably not done.
+                sleep(min(poll_timeout, 0.1))
+
+
+def synloop(obj, connection, consumer, strategies, namespace, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, **kwargs):
+    """Fallback blocking eventloop for transports that don't support AIO."""
+
+    def on_task_received(body, message):
+        try:
+            name = body['task']
+        except (KeyError, TypeError):
+            return handle_unknown_message(body, message)
+
+        try:
+            strategies[name](message, body, message.ack_log_error)
+        except KeyError as exc:
+            handle_unknown_task(body, message, exc)
+        except InvalidTaskError as exc:
+            handle_invalid_task(body, message, exc)
+
+    consumer.register_callback(on_task_received)
+    consumer.consume()
+
+    debug('Ready to accept tasks!')
+
+    while namespace.state != CLOSE and obj.connection:
+        state.maybe_shutdown()
+        if qos.prev != qos.value:         # pragma: no cover
+            qos.update()
+        try:
+            connection.drain_events(timeout=2.0)
+        except socket.timeout:
+            pass
+        except socket.error:
+            if namespace.state != CLOSE:  # pragma: no cover
+                raise

+ 7 - 0
celery/worker/state.py

@@ -53,6 +53,13 @@ should_stop = False
 should_terminate = False
 
 
+def maybe_shutdown():
+    if should_stop:
+        raise SystemExit()
+    elif should_terminate:
+        raise SystemTerminate()
+
+
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     active_requests.add(request)