Browse Source

[consumer bootsteps] Tests passing

Ask Solem 12 years ago
parent
commit
ec2e3c36ba

+ 5 - 5
celery/app/defaults.py

@@ -150,8 +150,8 @@ NAMESPACES = {
         'WORKER_DIRECT': Option(False, type='bool'),
         'WORKER_DIRECT': Option(False, type='bool'),
     },
     },
     'CELERYD': {
     'CELERYD': {
-        'AUTOSCALER': Option('celery.worker.autoscale.Autoscaler'),
-        'AUTORELOADER': Option('celery.worker.autoreload.Autoreloader'),
+        'AUTOSCALER': Option('celery.worker.autoscale:Autoscaler'),
+        'AUTORELOADER': Option('celery.worker.autoreload:Autoreloader'),
         'BOOT_STEPS': Option((), type='tuple'),
         'BOOT_STEPS': Option((), type='tuple'),
         'CONSUMER_BOOT_STEPS': Option((), type='tuple'),
         'CONSUMER_BOOT_STEPS': Option((), type='tuple'),
         'CONCURRENCY': Option(0, type='int'),
         'CONCURRENCY': Option(0, type='int'),
@@ -159,14 +159,14 @@ NAMESPACES = {
         'TIMER_PRECISION': Option(1.0, type='float'),
         'TIMER_PRECISION': Option(1.0, type='float'),
         'FORCE_EXECV': Option(True, type='bool'),
         'FORCE_EXECV': Option(True, type='bool'),
         'HIJACK_ROOT_LOGGER': Option(True, type='bool'),
         'HIJACK_ROOT_LOGGER': Option(True, type='bool'),
-        'CONSUMER': Option(type='string'),
+        'CONSUMER': Option('celery.worker.consumer:Consumer', type='string'),
         'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT),
         'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT),
         'LOG_COLOR': Option(type='bool'),
         'LOG_COLOR': Option(type='bool'),
         'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0',
         'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0',
                             alt='--loglevel argument'),
                             alt='--loglevel argument'),
         'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
         'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
                             alt='--logfile argument'),
                             alt='--logfile argument'),
-        'MEDIATOR': Option('celery.worker.mediator.Mediator'),
+        'MEDIATOR': Option('celery.worker.mediator:Mediator'),
         'MAX_TASKS_PER_CHILD': Option(type='int'),
         'MAX_TASKS_PER_CHILD': Option(type='int'),
         'POOL': Option(DEFAULT_POOL),
         'POOL': Option(DEFAULT_POOL),
         'POOL_PUTLOCKS': Option(True, type='bool'),
         'POOL_PUTLOCKS': Option(True, type='bool'),
@@ -180,7 +180,7 @@ NAMESPACES = {
     },
     },
     'CELERYBEAT': {
     'CELERYBEAT': {
         'SCHEDULE': Option({}, type='dict'),
         'SCHEDULE': Option({}, type='dict'),
-        'SCHEDULER': Option('celery.beat.PersistentScheduler'),
+        'SCHEDULER': Option('celery.beat:PersistentScheduler'),
         'SCHEDULE_FILENAME': Option('celerybeat-schedule'),
         'SCHEDULE_FILENAME': Option('celerybeat-schedule'),
         'MAX_LOOP_INTERVAL': Option(0, type='float'),
         'MAX_LOOP_INTERVAL': Option(0, type='float'),
         'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0',
         'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0',

+ 8 - 3
celery/apps/worker.py

@@ -101,6 +101,10 @@ class Worker(WorkController):
             enabled=not no_color if no_color is not None else no_color
             enabled=not no_color if no_color is not None else no_color
         )
         )
 
 
+    def on_init_namespace(self):
+        print('SETUP LOGGING: %r' % (self.redirect_stdouts, ))
+        self.setup_logging()
+
     def on_start(self):
     def on_start(self):
         WorkController.on_start(self)
         WorkController.on_start(self)
 
 
@@ -122,10 +126,11 @@ class Worker(WorkController):
 
 
         # Dump configuration to screen so we have some basic information
         # Dump configuration to screen so we have some basic information
         # for when users sends bug reports.
         # for when users sends bug reports.
-        print(str(self.colored.cyan(' \n', self.startup_info())) +
-              str(self.colored.reset(self.extra_info() or '')))
+        sys.__stdout__.write(
+            str(self.colored.cyan(' \n', self.startup_info())) +
+            str(self.colored.reset(self.extra_info() or '')) + '\n'
+        )
         self.set_process_status('-active-')
         self.set_process_status('-active-')
-        self.setup_logging()
         self.install_platform_tweaks(self)
         self.install_platform_tweaks(self)
 
 
     def on_consumer_ready(self, consumer):
     def on_consumer_ready(self, consumer):

+ 3 - 8
celery/tests/bin/test_celeryd.py

@@ -60,10 +60,7 @@ def disable_stdouts(fun):
 
 
 
 
 class Worker(cd.Worker):
 class Worker(cd.Worker):
-
-    def __init__(self, *args, **kwargs):
-        super(Worker, self).__init__(*args, **kwargs)
-        self.redirect_stdouts = False
+    redirect_stdouts = False
 
 
     def start(self, *args, **kwargs):
     def start(self, *args, **kwargs):
         self.on_start()
         self.on_start()
@@ -292,9 +289,7 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_redirect_stdouts(self):
     def test_redirect_stdouts(self):
-        worker = self.Worker()
-        worker.redirect_stdouts = False
-        worker.setup_logging()
+        worker = self.Worker(redirect_stdouts=False)
         with self.assertRaises(AttributeError):
         with self.assertRaises(AttributeError):
             sys.stdout.logger
             sys.stdout.logger
 
 
@@ -306,7 +301,7 @@ class test_Worker(WorkerAppCase):
             logging_setup[0] = True
             logging_setup[0] = True
 
 
         try:
         try:
-            worker = self.Worker()
+            worker = self.Worker(redirect_stdouts=False)
             worker.app.log.__class__._setup = False
             worker.app.log.__class__._setup = False
             worker.setup_logging()
             worker.setup_logging()
             self.assertTrue(logging_setup[0])
             self.assertTrue(logging_setup[0])

+ 123 - 84
celery/tests/worker/test_worker.py

@@ -28,7 +28,8 @@ from celery.worker.components import Queues, Timers, EvLoop, Pool
 from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
 from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker.job import Request
-from celery.worker.consumer import BlockingConsumer
+from celery.worker.consumer import Consumer
+from celery.worker.consumer import components as consumer_components
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
@@ -39,7 +40,14 @@ class PlaceHolder(object):
         pass
         pass
 
 
 
 
-class MyKombuConsumer(BlockingConsumer):
+def find_component(obj, typ):
+    for c in obj.namespace.boot_steps:
+        if isinstance(c, typ):
+            return c
+    raise Exception('Instance %s has no %s component' % (obj, typ))
+
+
+class MyKombuConsumer(Consumer):
     broadcast_consumer = Mock()
     broadcast_consumer = Mock()
     task_consumer = Mock()
     task_consumer = Mock()
 
 
@@ -227,45 +235,52 @@ class test_Consumer(Case):
 
 
     def test_start_when_closed(self):
     def test_start_when_closed(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.start()
         l.start()
 
 
     def test_connection(self):
     def test_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
 
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
         self.assertIsInstance(l.connection, Connection)
 
 
-        l._state = RUN
+        l.namespace.state = RUN
         l.event_dispatcher = None
         l.event_dispatcher = None
-        l.stop_consumers(close_connection=False)
+        l.restart()
         self.assertTrue(l.connection)
         self.assertTrue(l.connection)
 
 
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
         self.assertIsNone(l.task_consumer)
 
 
-        l.reset_connection()
+        l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
         self.assertIsInstance(l.connection, Connection)
-        l.stop_consumers()
+        l.restart()
 
 
         l.stop()
         l.stop()
-        l.close_connection()
+        l.shutdown()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
         self.assertIsNone(l.task_consumer)
 
 
     def test_close_connection(self):
     def test_close_connection(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
-        l.close_connection()
+        l.namespace.state = RUN
+        comp = find_component(l, consumer_components.ConsumerConnection)
+        conn = l.connection = Mock()
+        comp.shutdown(l)
+        self.assertTrue(conn.close.called)
+        self.assertIsNone(l.connection)
 
 
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         eventer = l.event_dispatcher = Mock()
         eventer = l.event_dispatcher = Mock()
         eventer.enabled = True
         eventer.enabled = True
         heart = l.heart = MockHeart()
         heart = l.heart = MockHeart()
-        l._state = RUN
-        l.stop_consumers()
+        l.namespace.state = RUN
+        Events = find_component(l, consumer_components.Events)
+        Events.shutdown(l)
+        Heart = find_component(l, consumer_components.Heartbeat)
+        Heart.shutdown(l)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
         self.assertTrue(heart.closed)
 
 
@@ -277,10 +292,11 @@ class test_Consumer(Case):
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
         l.pidbox_node = MockNode()
 
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(warn.call_count)
         self.assertTrue(warn.call_count)
 
 
-    @patch('celery.utils.timer2.to_timestamp')
+    @patch('celery.worker.consumer.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
         to_timestamp.side_effect = OverflowError()
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
@@ -292,7 +308,8 @@ class test_Consumer(Case):
         l.pidbox_node = MockNode()
         l.pidbox_node = MockNode()
         l.update_strategies()
         l.update_strategies()
 
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertTrue(m.acknowledged)
         self.assertTrue(m.acknowledged)
         self.assertTrue(to_timestamp.call_count)
         self.assertTrue(to_timestamp.call_count)
 
 
@@ -303,9 +320,9 @@ class test_Consumer(Case):
                            args=(1, 2), kwargs='foobarbaz', id=1)
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.update_strategies()
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
 
 
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         self.assertIn('Received invalid task message', error.call_args[0][0])
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
 
     @patch('celery.worker.consumer.crit')
     @patch('celery.worker.consumer.crit')
@@ -322,14 +339,25 @@ class test_Consumer(Case):
         self.assertTrue(message.ack.call_count)
         self.assertTrue(message.ack.call_count)
         self.assertIn("Can't decode message body", crit.call_args[0][0])
         self.assertIn("Can't decode message body", crit.call_args[0][0])
 
 
+    def _get_on_message(self, l):
+        l.qos = Mock()
+        l.event_dispatcher = Mock()
+        l.task_consumer = Mock()
+        l.connection = Mock()
+        l.connection.drain_events.side_effect = SystemExit()
+
+        with self.assertRaises(SystemExit):
+            l.loop(*l.loop_args())
+        self.assertTrue(l.task_consumer.register_callback.called)
+        return l.task_consumer.register_callback.call_args[0][0]
+
     def test_receieve_message(self):
     def test_receieve_message(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
         l.update_strategies()
-
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
 
 
         in_bucket = self.ready_queue.get_nowait()
         in_bucket = self.ready_queue.get_nowait()
         self.assertIsInstance(in_bucket, Request)
         self.assertIsInstance(in_bucket, Request)
@@ -339,10 +367,10 @@ class test_Consumer(Case):
 
 
     def test_start_connection_error(self):
     def test_start_connection_error(self):
 
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
             iterations = 0
 
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                 if not self.iterations:
                     self.iterations = 1
                     self.iterations = 1
                     raise KeyError('foo')
                     raise KeyError('foo')
@@ -360,10 +388,10 @@ class test_Consumer(Case):
         # Regression test for AMQPChannelExceptions that can occur within the
         # Regression test for AMQPChannelExceptions that can occur within the
         # consumer. (i.e. 404 errors)
         # consumer. (i.e. 404 errors)
 
 
-        class MockConsumer(BlockingConsumer):
+        class MockConsumer(Consumer):
             iterations = 0
             iterations = 0
 
 
-            def consume_messages(self):
+            def loop(self, *args, **kwargs):
                 if not self.iterations:
                 if not self.iterations:
                     self.iterations = 1
                     self.iterations = 1
                     raise KeyError('foo')
                     raise KeyError('foo')
@@ -377,7 +405,7 @@ class test_Consumer(Case):
         l.heart.stop()
         l.heart.stop()
         l.timer.stop()
         l.timer.stop()
 
 
-    def test_consume_messages_ignores_socket_timeout(self):
+    def test_loop_ignores_socket_timeout(self):
 
 
         class Connection(current_app.connection().__class__):
         class Connection(current_app.connection().__class__):
             obj = None
             obj = None
@@ -391,9 +419,9 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.connection.obj = l
         l.connection.obj = l
         l.qos = QoS(l.task_consumer, 10)
         l.qos = QoS(l.task_consumer, 10)
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
 
-    def test_consume_messages_when_socket_error(self):
+    def test_loop_when_socket_error(self):
 
 
         class Connection(current_app.connection().__class__):
         class Connection(current_app.connection().__class__):
             obj = None
             obj = None
@@ -403,19 +431,19 @@ class test_Consumer(Case):
                 raise socket.error('foo')
                 raise socket.error('foo')
 
 
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._state = RUN
+        l.namespace.state = RUN
         c = l.connection = Connection()
         c = l.connection = Connection()
         l.connection.obj = l
         l.connection.obj = l
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
         l.qos = QoS(l.task_consumer, 10)
         with self.assertRaises(socket.error):
         with self.assertRaises(socket.error):
-            l.consume_messages()
+            l.loop(*l.loop_args())
 
 
-        l._state = CLOSE
+        l.namespace.state = CLOSE
         l.connection = c
         l.connection = c
-        l.consume_messages()
+        l.loop(*l.loop_args())
 
 
-    def test_consume_messages(self):
+    def test_loop(self):
 
 
         class Connection(current_app.connection().__class__):
         class Connection(current_app.connection().__class__):
             obj = None
             obj = None
@@ -429,8 +457,8 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10)
         l.qos = QoS(l.task_consumer, 10)
 
 
-        l.consume_messages()
-        l.consume_messages()
+        l.loop(*l.loop_args())
+        l.loop(*l.loop_args())
         self.assertTrue(l.task_consumer.consume.call_count)
         self.assertTrue(l.task_consumer.consume.call_count)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
         l.task_consumer.qos = Mock()
         l.task_consumer.qos = Mock()
@@ -470,12 +498,13 @@ class test_Consumer(Case):
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
 
 
         l.task_consumer = Mock()
         l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
+        l.qos = QoS(l.task_consumer, 1)
         current_pcount = l.qos.value
         current_pcount = l.qos.value
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.enabled = False
         l.enabled = False
         l.update_strategies()
         l.update_strategies()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         l.timer.stop()
         l.timer.join(1)
         l.timer.join(1)
 
 
@@ -490,22 +519,23 @@ class test_Consumer(Case):
 
 
     def test_on_control(self):
     def test_on_control(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        l.reset_pidbox_node = Mock()
+        con = find_component(l, consumer_components.Controller)
+        con.pidbox_node = Mock()
+        con.reset_pidbox_node = Mock()
 
 
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.on_control('foo', 'bar')
+        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
 
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = KeyError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.pidbox_node = Mock()
+        con.pidbox_node.handle_message.side_effect = KeyError('foo')
+        con.on_control('foo', 'bar')
+        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
 
 
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = ValueError('foo')
-        l.on_control('foo', 'bar')
-        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
-        l.reset_pidbox_node.assert_called_with()
+        con.pidbox_node = Mock()
+        con.pidbox_node.handle_message.side_effect = ValueError('foo')
+        con.on_control('foo', 'bar')
+        con.pidbox_node.handle_message.assert_called_with('foo', 'bar')
+        con.reset_pidbox_node.assert_called_with()
 
 
     def test_revoke(self):
     def test_revoke(self):
         ready_queue = FastQueue()
         ready_queue = FastQueue()
@@ -517,7 +547,8 @@ class test_Consumer(Case):
         from celery.worker.state import revoked
         from celery.worker.state import revoked
         revoked.add(id)
         revoked.add(id)
 
 
-        l.receive_message(t.decode(), t)
+        callback = self._get_on_message(l)
+        callback(t.decode(), t)
         self.assertTrue(ready_queue.empty())
         self.assertTrue(ready_queue.empty())
 
 
     def test_receieve_message_not_registered(self):
     def test_receieve_message_not_registered(self):
@@ -526,7 +557,8 @@ class test_Consumer(Case):
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
 
 
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         with self.assertRaises(Empty):
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
             self.ready_queue.get_nowait()
         self.assertTrue(self.timer.empty())
         self.assertTrue(self.timer.empty())
@@ -542,7 +574,8 @@ class test_Consumer(Case):
         l.connection_errors = (socket.error, )
         l.connection_errors = (socket.error, )
         m.reject = Mock()
         m.reject = Mock()
         m.reject.side_effect = socket.error('foo')
         m.reject.side_effect = socket.error('foo')
-        self.assertFalse(l.receive_message(m.decode(), m))
+        callback = self._get_on_message(l)
+        self.assertFalse(callback(m.decode(), m))
         self.assertTrue(warn.call_count)
         self.assertTrue(warn.call_count)
         with self.assertRaises(Empty):
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
             self.ready_queue.get_nowait()
@@ -560,16 +593,17 @@ class test_Consumer(Case):
                            eta=(datetime.now() +
                            eta=(datetime.now() +
                                timedelta(days=1)).isoformat())
                                timedelta(days=1)).isoformat())
 
 
-        l.reset_connection()
+        l.namespace.start(l)
         p = l.app.conf.BROKER_CONNECTION_RETRY
         p = l.app.conf.BROKER_CONNECTION_RETRY
         l.app.conf.BROKER_CONNECTION_RETRY = False
         l.app.conf.BROKER_CONNECTION_RETRY = False
         try:
         try:
-            l.reset_connection()
+            l.namespace.start(l)
         finally:
         finally:
             l.app.conf.BROKER_CONNECTION_RETRY = p
             l.app.conf.BROKER_CONNECTION_RETRY = p
-        l.stop_consumers()
+        l.restart()
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
+        callback = self._get_on_message(l)
+        callback(m.decode(), m)
         l.timer.stop()
         l.timer.stop()
         in_hold = l.timer.queue[0]
         in_hold = l.timer.queue[0]
         self.assertEqual(len(in_hold), 3)
         self.assertEqual(len(in_hold), 3)
@@ -583,24 +617,27 @@ class test_Consumer(Case):
 
 
     def test_reset_pidbox_node(self):
     def test_reset_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.pidbox_node = Mock()
-        chan = l.pidbox_node.channel = Mock()
+        con = find_component(l, consumer_components.Controller)
+        con.pidbox_node = Mock()
+        chan = con.pidbox_node.channel = Mock()
         l.connection = Mock()
         l.connection = Mock()
         chan.close.side_effect = socket.error('foo')
         chan.close.side_effect = socket.error('foo')
         l.connection_errors = (socket.error, )
         l.connection_errors = (socket.error, )
-        l.reset_pidbox_node()
+        con.reset_pidbox_node()
         chan.close.assert_called_with()
         chan.close.assert_called_with()
 
 
     def test_reset_pidbox_node_green(self):
     def test_reset_pidbox_node_green(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        con = find_component(l, consumer_components.Controller)
         l.pool = Mock()
         l.pool = Mock()
         l.pool.is_green = True
         l.pool.is_green = True
-        l.reset_pidbox_node()
-        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+        con.reset_pidbox_node()
+        l.pool.spawn_n.assert_called_with(con._green_pidbox_node)
 
 
     def test__green_pidbox_node(self):
     def test__green_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l.pidbox_node = Mock()
         l.pidbox_node = Mock()
+        cont = find_component(l, consumer_components.Controller)
 
 
         class BConsumer(Mock):
         class BConsumer(Mock):
 
 
@@ -611,7 +648,7 @@ class test_Consumer(Case):
             def __exit__(self, *exc_info):
             def __exit__(self, *exc_info):
                 self.cancel()
                 self.cancel()
 
 
-        l.pidbox_node.listen = BConsumer()
+        cont.pidbox_node.listen = BConsumer()
         connections = []
         connections = []
 
 
         class Connection(object):
         class Connection(object):
@@ -640,18 +677,19 @@ class test_Consumer(Case):
                     self.calls += 1
                     self.calls += 1
                     raise socket.timeout()
                     raise socket.timeout()
                 self.obj.connection = None
                 self.obj.connection = None
-                self.obj._pidbox_node_shutdown.set()
+                cont._pidbox_node_shutdown.set()
 
 
             def close(self):
             def close(self):
                 self.closed = True
                 self.closed = True
 
 
         l.connection = Mock()
         l.connection = Mock()
         l._open_connection = lambda: Connection(obj=l)
         l._open_connection = lambda: Connection(obj=l)
-        l._green_pidbox_node()
+        controller = find_component(l, consumer_components.Controller)
+        controller._green_pidbox_node()
 
 
-        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
-        self.assertTrue(l.broadcast_consumer)
-        l.broadcast_consumer.consume.assert_called_with()
+        cont.pidbox_node.listen.assert_called_with(callback=cont.on_control)
+        self.assertTrue(cont.broadcast_consumer)
+        cont.broadcast_consumer.consume.assert_called_with()
 
 
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
         self.assertTrue(connections[0].closed)
         self.assertTrue(connections[0].closed)
@@ -673,12 +711,13 @@ class test_Consumer(Case):
 
 
     def test_stop_pidbox_node(self):
     def test_stop_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l._pidbox_node_stopped = Event()
-        l._pidbox_node_shutdown = Event()
-        l._pidbox_node_stopped.set()
-        l.stop_pidbox_node()
+        cont = find_component(l, consumer_components.Controller)
+        cont._pidbox_node_stopped = Event()
+        cont._pidbox_node_shutdown = Event()
+        cont._pidbox_node_stopped.set()
+        cont.stop_pidbox_node()
 
 
-    def test_start__consume_messages(self):
+    def test_start__loop(self):
 
 
         class _QoS(object):
         class _QoS(object):
             prev = 3
             prev = 3
@@ -703,18 +742,18 @@ class test_Consumer(Case):
         l.connection = Connection()
         l.connection = Connection()
         l.iterations = 0
         l.iterations = 0
 
 
-        def raises_KeyError(limit=None):
+        def raises_KeyError(*args, **kwargs):
             l.iterations += 1
             l.iterations += 1
             if l.qos.prev != l.qos.value:
             if l.qos.prev != l.qos.value:
                 l.qos.update()
                 l.qos.update()
             if l.iterations >= 2:
             if l.iterations >= 2:
                 raise KeyError('foo')
                 raise KeyError('foo')
 
 
-        l.consume_messages = raises_KeyError
+        l.loop = raises_KeyError
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             l.start()
             l.start()
         self.assertTrue(init_callback.call_count)
         self.assertTrue(init_callback.call_count)
-        self.assertEqual(l.iterations, 1)
+        self.assertEqual(l.iterations, 2)
         self.assertEqual(l.qos.prev, l.qos.value)
         self.assertEqual(l.qos.prev, l.qos.value)
 
 
         init_callback.reset_mock()
         init_callback.reset_mock()
@@ -724,25 +763,25 @@ class test_Consumer(Case):
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.connection = Connection()
         l.connection = Connection()
-        l.consume_messages = Mock(side_effect=socket.error('foo'))
+        l.loop = Mock(side_effect=socket.error('foo'))
         with self.assertRaises(socket.error):
         with self.assertRaises(socket.error):
             l.start()
             l.start()
         self.assertTrue(init_callback.call_count)
         self.assertTrue(init_callback.call_count)
-        self.assertTrue(l.consume_messages.call_count)
+        self.assertTrue(l.loop.call_count)
 
 
     def test_reset_connection_with_no_node(self):
     def test_reset_connection_with_no_node(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         self.assertEqual(None, l.pool)
         self.assertEqual(None, l.pool)
-        l.reset_connection()
+        l.namespace.start(l)
 
 
     def test_on_task_revoked(self):
     def test_on_task_revoked(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task = Mock()
         task.revoked.return_value = True
         task.revoked.return_value = True
         l.on_task(task)
         l.on_task(task)
 
 
     def test_on_task_no_events(self):
     def test_on_task_no_events(self):
-        l = BlockingConsumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.ready_queue, timer=self.timer)
         task = Mock()
         task = Mock()
         task.revoked.return_value = False
         task.revoked.return_value = False
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()

+ 5 - 0
celery/worker/__init__.py

@@ -123,12 +123,17 @@ class WorkController(configurated):
         # Initialize boot steps
         # Initialize boot steps
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.components = []
         self.components = []
+        self.on_init_namespace()
         self.namespace = Namespace(app=self.app,
         self.namespace = Namespace(app=self.app,
                                    on_start=self.on_start,
                                    on_start=self.on_start,
                                    on_close=self.on_close,
                                    on_close=self.on_close,
                                    on_stopped=self.on_stopped)
                                    on_stopped=self.on_stopped)
         self.namespace.apply(self, **kwargs)
         self.namespace.apply(self, **kwargs)
 
 
+
+    def on_init_namespace(self):
+        pass
+
     def on_before_init(self, **kwargs):
     def on_before_init(self, **kwargs):
         pass
         pass
 
 

+ 21 - 169
celery/worker/consumer/__init__.py

@@ -13,15 +13,10 @@ from __future__ import absolute_import
 import logging
 import logging
 import socket
 import socket
 
 
-from time import sleep
-from Queue import Empty
-
 from kombu.syn import _detect_environment
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
-from kombu.utils.eventio import READ, WRITE, ERR
 
 
 from celery.app import app_or_default
 from celery.app import app_or_default
-from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.task.trace import build_tracer
 from celery.task.trace import build_tracer
 from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.functional import noop
 from celery.utils.functional import noop
@@ -30,16 +25,16 @@ from celery.utils.log import get_logger
 from celery.utils.text import dump_body
 from celery.utils.text import dump_body
 from celery.utils.timeutils import humanize_seconds
 from celery.utils.timeutils import humanize_seconds
 from celery.worker import state
 from celery.worker import state
+from celery.worker.state import maybe_shutdown
 from celery.worker.bootsteps import Namespace as _NS, StartStopComponent, CLOSE
 from celery.worker.bootsteps import Namespace as _NS, StartStopComponent, CLOSE
 
 
+from . import loops
+
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 info, warn, error, crit = (logger.info, logger.warn,
 info, warn, error, crit = (logger.info, logger.warn,
                            logger.error, logger.critical)
                            logger.error, logger.critical)
 task_reserved = state.task_reserved
 task_reserved = state.task_reserved
 
 
-#: Heartbeat check is called every heartbeat_seconds' / rate'.
-AMQHEARTBEAT_RATE = 2.0
-
 CONNECTION_RETRY = """\
 CONNECTION_RETRY = """\
 consumer: Connection to broker lost. \
 consumer: Connection to broker lost. \
 Trying to re-establish the connection...\
 Trying to re-establish the connection...\
@@ -102,13 +97,9 @@ class Component(StartStopComponent):
     name = 'worker.consumer'
     name = 'worker.consumer'
     last = True
     last = True
 
 
-    def Consumer(self, w):
-        return (w.consumer_cls or
-                Consumer if w.hub else BlockingConsumer)
-
     def create(self, w):
     def create(self, w):
         prefetch_count = w.concurrency * w.prefetch_multiplier
         prefetch_count = w.concurrency * w.prefetch_multiplier
-        c = w.consumer = self.instantiate(self.Consumer(w),
+        c = w.consumer = self.instantiate(w.consumer_cls,
                 w.ready_queue,
                 w.ready_queue,
                 hostname=w.hostname,
                 hostname=w.hostname,
                 send_events=w.send_events,
                 send_events=w.send_events,
@@ -191,6 +182,8 @@ class Consumer(object):
         self.channel_errors = conninfo.channel_errors
         self.channel_errors = conninfo.channel_errors
 
 
         self._does_info = logger.isEnabledFor(logging.INFO)
         self._does_info = logger.isEnabledFor(logging.INFO)
+        if not hasattr(self, 'loop'):
+            self.loop = loops.asynloop if hub else loops.synloop
         if hub:
         if hub:
             hub.on_init.append(self.on_poll_init)
             hub.on_init.append(self.on_poll_init)
         self.hub = hub
         self.hub = hub
@@ -226,17 +219,24 @@ class Consumer(object):
         consuming messages.
         consuming messages.
 
 
         """
         """
-        ns = self.namespace
+        ns, loop, loop_args = self.namespace, self.loop, self.loop_args()
         while ns.state != CLOSE:
         while ns.state != CLOSE:
-            self.maybe_shutdown()
+            maybe_shutdown()
             try:
             try:
-                self.namespace.start(self)
-                self.consume_messages()
+                ns.start(self)
+                loop(*loop_args)
             except self.connection_errors + self.channel_errors:
             except self.connection_errors + self.channel_errors:
                 error(CONNECTION_RETRY, exc_info=True)
                 error(CONNECTION_RETRY, exc_info=True)
-                ns.restart(self)
-            ns.close(self)
-            ns.state = CLOSE
+                self.restart()
+
+    def loop_args(self):
+        return (self, self.connection, self.task_consumer,
+                self.strategies, self.namespace, self.hub, self.qos,
+                self.amqheartbeat, self.handle_unknown_message,
+                self.handle_unknown_task, self.handle_invalid_task)
+
+    def restart(self):
+        return self.namespace.restart(self)
 
 
     def on_poll_init(self, hub):
     def on_poll_init(self, hub):
         hub.update_readers(self.connection.eventmap)
         hub.update_readers(self.connection.eventmap)
@@ -304,7 +304,7 @@ class Consumer(object):
 
 
         return conn.ensure_connection(_error_handler,
         return conn.ensure_connection(_error_handler,
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
-                    callback=self.maybe_shutdown)
+                    callback=maybe_shutdown)
 
 
     def stop(self):
     def stop(self):
         """Stop consuming.
         """Stop consuming.
@@ -315,12 +315,6 @@ class Consumer(object):
         """
         """
         self.namespace.stop(self)
         self.namespace.stop(self)
 
 
-    def maybe_shutdown(self):
-        if state.should_stop:
-            raise SystemExit()
-        elif state.should_terminate:
-            raise SystemTerminate()
-
     def add_task_queue(self, queue, exchange=None, exchange_type=None,
     def add_task_queue(self, queue, exchange=None, exchange_type=None,
             routing_key=None, **options):
             routing_key=None, **options):
         cset = self.task_consumer
         cset = self.task_consumer
@@ -358,106 +352,6 @@ class Consumer(object):
             conninfo.pop('password', None)  # don't send password.
             conninfo.pop('password', None)  # don't send password.
         return {'broker': conninfo, 'prefetch_count': self.qos.value}
         return {'broker': conninfo, 'prefetch_count': self.qos.value}
 
 
-    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
-            hbrate=AMQHEARTBEAT_RATE):
-        """Consume messages forever (or until an exception is raised)."""
-
-        with self.hub as hub:
-            ns = self.namespace
-            qos = self.qos
-            update_qos = qos.update
-            update_readers = hub.update_readers
-            readers, writers = hub.readers, hub.writers
-            poll = hub.poller.poll
-            fire_timers = hub.fire_timers
-            scheduled = hub.timer._queue
-            connection = self.connection
-            hb = self.amqheartbeat
-            hbtick = connection.heartbeat_check
-            on_poll_start = connection.transport.on_poll_start
-            on_poll_empty = connection.transport.on_poll_empty
-            strategies = self.strategies
-            drain_nowait = connection.drain_nowait
-            on_task_callbacks = hub.on_task
-            keep_draining = connection.transport.nb_keep_draining
-
-            if hb and connection.supports_heartbeats:
-                hub.timer.apply_interval(
-                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))
-
-            def on_task_received(body, message):
-                if on_task_callbacks:
-                    [callback() for callback in on_task_callbacks]
-                try:
-                    name = body['task']
-                except (KeyError, TypeError):
-                    return self.handle_unknown_message(body, message)
-                try:
-                    strategies[name](message, body, message.ack_log_error)
-                except KeyError as exc:
-                    self.handle_unknown_task(body, message, exc)
-                except InvalidTaskError as exc:
-                    self.handle_invalid_task(body, message, exc)
-                #fire_timers()
-
-            self.task_consumer.callbacks = [on_task_received]
-            self.task_consumer.consume()
-
-            debug('Ready to accept tasks!')
-
-            while ns.state != CLOSE and self.connection:
-                # shutdown if signal handlers told us to.
-                if state.should_stop:
-                    raise SystemExit()
-                elif state.should_terminate:
-                    raise SystemTerminate()
-
-                # fire any ready timers, this also returns
-                # the number of seconds until we need to fire timers again.
-                poll_timeout = fire_timers() if scheduled else 1
-
-                # We only update QoS when there is no more messages to read.
-                # This groups together qos calls, and makes sure that remote
-                # control commands will be prioritized over task messages.
-                if qos.prev != qos.value:
-                    update_qos()
-
-                update_readers(on_poll_start())
-                if readers or writers:
-                    connection.more_to_read = True
-                    while connection.more_to_read:
-                        try:
-                            events = poll(poll_timeout)
-                        except ValueError:  # Issue 882
-                            return
-                        if not events:
-                            on_poll_empty()
-                        for fileno, event in events or ():
-                            try:
-                                if event & READ:
-                                    readers[fileno](fileno, event)
-                                if event & WRITE:
-                                    writers[fileno](fileno, event)
-                                if event & ERR:
-                                    for handlermap in readers, writers:
-                                        try:
-                                            handlermap[fileno](fileno, event)
-                                        except KeyError:
-                                            pass
-                            except (KeyError, Empty):
-                                continue
-                            except socket.error:
-                                if ns.state != CLOSE:  # pragma: no cover
-                                    raise
-                        if keep_draining:
-                            drain_nowait()
-                            poll_timeout = 0
-                        else:
-                            connection.more_to_read = False
-                else:
-                    # no sockets yet, startup is probably not done.
-                    sleep(min(poll_timeout, 0.1))
-
     def on_task(self, task, task_reserved=task_reserved):
     def on_task(self, task, task_reserved=task_reserved):
         """Handle received task.
         """Handle received task.
 
 
@@ -527,45 +421,3 @@ class Consumer(object):
         for name, task in self.app.tasks.iteritems():
         for name, task in self.app.tasks.iteritems():
             S[name] = task.start_strategy(app, self)
             S[name] = task.start_strategy(app, self)
             task.__trace__ = build_tracer(name, task, loader, hostname)
             task.__trace__ = build_tracer(name, task, loader, hostname)
-
-
-class BlockingConsumer(Consumer):
-
-    def consume_messages(self):
-        # receive_message handles incoming messages.
-        self.task_consumer.register_callback(self.receive_message)
-        self.task_consumer.consume()
-
-        debug('Ready to accept tasks!')
-        ns = self.namespace
-
-        while ns.state != CLOSE and self.connection:
-            self.maybe_shutdown()
-            if self.qos.prev != self.qos.value:     # pragma: no cover
-                self.qos.update()
-            try:
-                self.connection.drain_events(timeout=10.0)
-            except socket.timeout:
-                pass
-            except socket.error:
-                if ns.state != CLOSE:            # pragma: no cover
-                    raise
-
-    def receive_message(self, body, message):
-        """Handles incoming messages.
-
-        :param body: The message body.
-        :param message: The kombu message object.
-
-        """
-        try:
-            name = body['task']
-        except (KeyError, TypeError):
-            return self.handle_unknown_message(body, message)
-
-        try:
-            self.strategies[name](message, body, message.ack_log_error)
-        except KeyError as exc:
-            self.handle_unknown_task(body, message, exc)
-        except InvalidTaskError as exc:
-            self.handle_invalid_task(body, message, exc)

+ 4 - 5
celery/worker/consumer/components.py

@@ -67,7 +67,7 @@ class Events(StartStopComponent):
                     c.maybe_conn_error(c.event_dispatcher.close)
                     c.maybe_conn_error(c.event_dispatcher.close)
 
 
     def shutdown(self, c):
     def shutdown(self, c):
-        pass
+        self.stop(c)
 
 
 
 
 class Heartbeat(StartStopComponent):
 class Heartbeat(StartStopComponent):
@@ -88,7 +88,7 @@ class Heartbeat(StartStopComponent):
             c.heart = c.heart.stop()
             c.heart = c.heart.stop()
 
 
     def shutdown(self, c):
     def shutdown(self, c):
-        pass
+        self.stop(c)
 
 
 
 
 class Controller(StartStopComponent):
 class Controller(StartStopComponent):
@@ -100,7 +100,6 @@ class Controller(StartStopComponent):
 
 
     def __init__(self, c, **kwargs):
     def __init__(self, c, **kwargs):
         self.app = c.app
         self.app = c.app
-        self.pool = c.pool
         pidbox_state = AttributeDict(
         pidbox_state = AttributeDict(
             app=c.app, hostname=c.hostname, consumer=c,
             app=c.app, hostname=c.hostname, consumer=c,
         )
         )
@@ -141,8 +140,8 @@ class Controller(StartStopComponent):
         if self.pidbox_node and self.pidbox_node.channel:
         if self.pidbox_node and self.pidbox_node.channel:
             c.maybe_conn_error(self.pidbox_node.channel.close)
             c.maybe_conn_error(self.pidbox_node.channel.close)
 
 
-        if self.pool is not None and self.pool.is_green:
-            return self.pool.spawn_n(self._green_pidbox_node)
+        if c.pool is not None and c.pool.is_green:
+            return c.pool.spawn_n(self._green_pidbox_node)
         self.pidbox_node.channel = c.connection.channel()
         self.pidbox_node.channel = c.connection.channel()
         self.broadcast_consumer = self.pidbox_node.listen(
         self.broadcast_consumer = self.pidbox_node.listen(
                                         callback=self.on_control)
                                         callback=self.on_control)

+ 160 - 0
celery/worker/consumer/loops.py

@@ -0,0 +1,160 @@
+"""
+celery.worker.consumer.loop
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Worker eventloop.
+
+"""
+from __future__ import absolute_import
+
+import socket
+
+from time import sleep
+from Queue import Empty
+
+from kombu.utils.eventio import READ, WRITE, ERR
+
+from celery.exceptions import InvalidTaskError, SystemTerminate
+from celery.utils.log import get_logger
+from celery.worker import state
+from celery.worker.bootsteps import CLOSE
+
+logger = get_logger(__name__)
+debug = logger.debug
+
+#: Heartbeat check is called every heartbeat_seconds' / rate'.
+AMQHEARTBEAT_RATE = 2.0
+
+
+def asynloop(obj, connection, consumer, strategies, namespace, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, sleep=sleep, min=min, Empty=Empty,
+        hbrate=AMQHEARTBEAT_RATE):
+    """Non-blocking eventloop consuming messages until connection is lost,
+    or shutdown is requested."""
+
+    with hub as hub:
+        update_qos = qos.update
+        update_readers = hub.update_readers
+        readers, writers = hub.readers, hub.writers
+        poll = hub.poller.poll
+        fire_timers = hub.fire_timers
+        scheduled = hub.timer._queue
+        hbtick = connection.heartbeat_check
+        on_poll_start = connection.transport.on_poll_start
+        on_poll_empty = connection.transport.on_poll_empty
+        drain_nowait = connection.drain_nowait
+        on_task_callbacks = hub.on_task
+        keep_draining = connection.transport.nb_keep_draining
+
+        if heartbeat and connection.supports_heartbeats:
+            hub.timer.apply_interval(
+                heartbeat * 1000.0 / hbrate, hbtick, (hbrate, ))
+
+        def on_task_received(body, message):
+            if on_task_callbacks:
+                [callback() for callback in on_task_callbacks]
+            try:
+                name = body['task']
+            except (KeyError, TypeError):
+                return handle_unknown_message(body, message)
+            try:
+                strategies[name](message, body, message.ack_log_error)
+            except KeyError as exc:
+                handle_unknown_task(body, message, exc)
+            except InvalidTaskError as exc:
+                handle_invalid_task(body, message, exc)
+
+        consumer.callbacks = [on_task_received]
+        consumer.consume()
+
+        debug('Ready to accept tasks!')
+
+        while namespace.state != CLOSE and obj.connection:
+            # shutdown if signal handlers told us to.
+            if state.should_stop:
+                raise SystemExit()
+            elif state.should_terminate:
+                raise SystemTerminate()
+
+            # fire any ready timers, this also returns
+            # the number of seconds until we need to fire timers again.
+            poll_timeout = fire_timers() if scheduled else 1
+
+            # We only update QoS when there is no more messages to read.
+            # This groups together qos calls, and makes sure that remote
+            # control commands will be prioritized over task messages.
+            if qos.prev != qos.value:
+                update_qos()
+
+            update_readers(on_poll_start())
+            if readers or writers:
+                connection.more_to_read = True
+                while connection.more_to_read:
+                    try:
+                        events = poll(poll_timeout)
+                    except ValueError:  # Issue 882
+                        return
+                    if not events:
+                        on_poll_empty()
+                    for fileno, event in events or ():
+                        try:
+                            if event & READ:
+                                readers[fileno](fileno, event)
+                            if event & WRITE:
+                                writers[fileno](fileno, event)
+                            if event & ERR:
+                                for handlermap in readers, writers:
+                                    try:
+                                        handlermap[fileno](fileno, event)
+                                    except KeyError:
+                                        pass
+                        except (KeyError, Empty):
+                            continue
+                        except socket.error:
+                            if namespace.state != CLOSE:  # pragma: no cover
+                                raise
+                    if keep_draining:
+                        drain_nowait()
+                        poll_timeout = 0
+                    else:
+                        connection.more_to_read = False
+            else:
+                # no sockets yet, startup is probably not done.
+                sleep(min(poll_timeout, 0.1))
+
+
+def synloop(obj, connection, consumer, strategies, namespace, hub, qos,
+        heartbeat, handle_unknown_message, handle_unknown_task,
+        handle_invalid_task, **kwargs):
+    """Fallback blocking eventloop for transports that doesn't support AIO."""
+
+    def on_task_received(body, message):
+        try:
+            name = body['task']
+        except (KeyError, TypeError):
+            return handle_unknown_message(body, message)
+
+        try:
+            strategies[name](message, body, message.ack_log_error)
+        except KeyError as exc:
+            handle_unknown_task(body, message, exc)
+        except InvalidTaskError as exc:
+            handle_invalid_task(body, message, exc)
+
+    consumer.register_callback(on_task_received)
+    consumer.consume()
+
+    debug('Ready to accept tasks!')
+
+    while namespace.state != CLOSE and obj.connection:
+        state.maybe_shutdown()
+        if qos.prev != qos.value:         # pragma: no cover
+            qos.update()
+        try:
+            connection.drain_events(timeout=2.0)
+        except socket.timeout:
+            pass
+        except socket.error:
+            if namespace.state != CLOSE:  # pragma: no cover
+                raise

+ 7 - 0
celery/worker/state.py

@@ -53,6 +53,13 @@ should_stop = False
 should_terminate = False
 should_terminate = False
 
 
 
 
+def maybe_shutdown():
+    if should_stop:
+        raise SystemExit()
+    elif should_terminate:
+        raise SystemTerminate()
+
+
 def task_accepted(request):
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     """Updates global state when a task has been accepted."""
     active_requests.add(request)
     active_requests.add(request)