Ver código fonte

[Worker] Moves each consumer bootstep into separated module.

Ask Solem 9 anos atrás
pai
commit
c31b5cf54d

+ 32 - 38
celery/tests/worker/test_consumer.py

@@ -7,16 +7,12 @@ from billiard.exceptions import RestartFreqExceeded
 
 
 from celery.datastructures import LimitedSet
 from celery.datastructures import LimitedSet
 from celery.worker import state as worker_state
 from celery.worker import state as worker_state
-from celery.worker.consumer import (
-    Consumer,
-    Heart,
-    Tasks,
-    Agent,
-    Mingle,
-    Gossip,
-    dump_body,
-    CLOSE,
-)
+from celery.worker.consumer.agent import Agent
+from celery.worker.consumer.consumer import CLOSE, Consumer, dump_body
+from celery.worker.consumer.gossip import Gossip
+from celery.worker.consumer.heart import Heart
+from celery.worker.consumer.mingle import Mingle
+from celery.worker.consumer.tasks import Tasks
 
 
 from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch
 from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch
 
 
@@ -65,19 +61,19 @@ class test_Consumer(AppCase):
         self.assertEqual(c.amqheartbeat, 20)
         self.assertEqual(c.amqheartbeat, 20)
 
 
     def test_gevent_bug_disables_connection_timeout(self):
     def test_gevent_bug_disables_connection_timeout(self):
-        with patch('celery.worker.consumer._detect_environment') as de:
-            de.return_value = 'gevent'
+        with patch('celery.worker.consumer.consumer._detect_environment') as d:
+            d.return_value = 'gevent'
             self.app.conf.broker_connection_timeout = 33.33
             self.app.conf.broker_connection_timeout = 33.33
             self.get_consumer()
             self.get_consumer()
             self.assertIsNone(self.app.conf.broker_connection_timeout)
             self.assertIsNone(self.app.conf.broker_connection_timeout)
 
 
     def test_limit_moved_to_pool(self):
     def test_limit_moved_to_pool(self):
-        with patch('celery.worker.consumer.task_reserved') as reserved:
+        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
             c = self.get_consumer()
             c = self.get_consumer()
             c.on_task_request = Mock(name='on_task_request')
             c.on_task_request = Mock(name='on_task_request')
             request = Mock(name='request')
             request = Mock(name='request')
             c._limit_move_to_pool(request)
             c._limit_move_to_pool(request)
-            reserved.assert_called_with(request)
+            reserv.assert_called_with(request)
             c.on_task_request.assert_called_with(request)
             c.on_task_request.assert_called_with(request)
 
 
     def test_update_prefetch_count(self):
     def test_update_prefetch_count(self):
@@ -112,17 +108,17 @@ class test_Consumer(AppCase):
     def test_limit_task(self):
     def test_limit_task(self):
         c = self.get_consumer()
         c = self.get_consumer()
 
 
-        with patch('celery.worker.consumer.task_reserved') as reserved:
+        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
             bucket = Mock()
             bucket = Mock()
             request = Mock()
             request = Mock()
             bucket.can_consume.return_value = True
             bucket.can_consume.return_value = True
 
 
             c._limit_task(request, bucket, 3)
             c._limit_task(request, bucket, 3)
             bucket.can_consume.assert_called_with(3)
             bucket.can_consume.assert_called_with(3)
-            reserved.assert_called_with(request)
+            reserv.assert_called_with(request)
             c.on_task_request.assert_called_with(request)
             c.on_task_request.assert_called_with(request)
 
 
-        with patch('celery.worker.consumer.task_reserved') as reserved:
+        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
             bucket.can_consume.return_value = False
             bucket.can_consume.return_value = False
             bucket.expected_time.return_value = 3.33
             bucket.expected_time.return_value = 3.33
             limit_order = c._limit_order
             limit_order = c._limit_order
@@ -134,7 +130,7 @@ class test_Consumer(AppCase):
                 priority=c._limit_order,
                 priority=c._limit_order,
             )
             )
             bucket.expected_time.assert_called_with(4)
             bucket.expected_time.assert_called_with(4)
-            self.assertFalse(reserved.called)
+            self.assertFalse(reserv.called)
 
 
     def test_start_blueprint_raises_EMFILE(self):
     def test_start_blueprint_raises_EMFILE(self):
         c = self.get_consumer()
         c = self.get_consumer()
@@ -153,7 +149,7 @@ class test_Consumer(AppCase):
         c._restart_state.step.side_effect = se
         c._restart_state.step.side_effect = se
         c.blueprint.start.side_effect = socket.error()
         c.blueprint.start.side_effect = socket.error()
 
 
-        with patch('celery.worker.consumer.sleep') as sleep:
+        with patch('celery.worker.consumer.consumer.sleep') as sleep:
             c.start()
             c.start()
             sleep.assert_called_with(1)
             sleep.assert_called_with(1)
 
 
@@ -182,12 +178,12 @@ class test_Consumer(AppCase):
         c.register_with_event_loop(Mock(name='loop'))
         c.register_with_event_loop(Mock(name='loop'))
 
 
     def test_on_close_clears_semaphore_timer_and_reqs(self):
     def test_on_close_clears_semaphore_timer_and_reqs(self):
-        with patch('celery.worker.consumer.reserved_requests') as reserved:
+        with patch('celery.worker.consumer.consumer.reserved_requests') as reserv:
             c = self.get_consumer()
             c = self.get_consumer()
             c.on_close()
             c.on_close()
             c.controller.semaphore.clear.assert_called_with()
             c.controller.semaphore.clear.assert_called_with()
             c.timer.clear.assert_called_with()
             c.timer.clear.assert_called_with()
-            reserved.clear.assert_called_with()
+            reserv.clear.assert_called_with()
             c.pool.flush.assert_called_with()
             c.pool.flush.assert_called_with()
 
 
             c.controller = None
             c.controller = None
@@ -375,18 +371,16 @@ class test_Gossip(AppCase):
         c.app.connection_for_read = _amqp_connection()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
         g = Gossip(c)
         g.start(c)
         g.start(c)
-
-        with patch('celery.worker.consumer.signature') as signature:
-            sig = signature.return_value = Mock()
-            task = Mock()
+        signature = g.app.signature = Mock(name='app.signature')
+        task = Mock()
+        g.call_task(task)
+        signature.assert_called_with(task)
+        signature.return_value.apply_async.assert_called_with()
+
+        signature.return_value.apply_async.side_effect = MemoryError()
+        with patch('celery.worker.consumer.gossip.error') as error:
             g.call_task(task)
             g.call_task(task)
-            signature.assert_called_with(task, app=c.app)
-            sig.apply_async.assert_called_with()
-
-            sig.apply_async.side_effect = MemoryError()
-            with patch('celery.worker.consumer.error') as error:
-                g.call_task(task)
-                self.assertTrue(error.called)
+            self.assertTrue(error.called)
 
 
     def Event(self, id='id', clock=312,
     def Event(self, id='id', clock=312,
               hostname='foo@example.com', pid=4312,
               hostname='foo@example.com', pid=4312,
@@ -414,7 +408,7 @@ class test_Gossip(AppCase):
         g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
         g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
 
 
         event.pop('clock')
         event.pop('clock')
-        with patch('celery.worker.consumer.error') as error:
+        with patch('celery.worker.consumer.gossip.error') as error:
             g.on_elect(event)
             g.on_elect(event)
             self.assertTrue(error.called)
             self.assertTrue(error.called)
 
 
@@ -444,7 +438,7 @@ class test_Gossip(AppCase):
         g.on_elect(e3)
         g.on_elect(e3)
         self.assertEqual(len(g.consensus_requests['id1']), 3)
         self.assertEqual(len(g.consensus_requests['id1']), 3)
 
 
-        with patch('celery.worker.consumer.info'):
+        with patch('celery.worker.consumer.gossip.info'):
             g.on_elect_ack(e1)
             g.on_elect_ack(e1)
             self.assertEqual(len(g.consensus_replies['id1']), 1)
             self.assertEqual(len(g.consensus_replies['id1']), 1)
             g.on_elect_ack(e2)
             g.on_elect_ack(e2)
@@ -474,7 +468,7 @@ class test_Gossip(AppCase):
         c.app.connection_for_read = _amqp_connection()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
         g = Gossip(c)
         g.election_handlers = {}
         g.election_handlers = {}
-        with patch('celery.worker.consumer.error') as error:
+        with patch('celery.worker.consumer.gossip.error') as error:
             self.setup_election(g, c)
             self.setup_election(g, c)
             self.assertTrue(error.called)
             self.assertTrue(error.called)
 
 
@@ -482,7 +476,7 @@ class test_Gossip(AppCase):
         c = self.Consumer()
         c = self.Consumer()
         c.app.connection_for_read = _amqp_connection()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
         g = Gossip(c)
-        with patch('celery.worker.consumer.debug') as debug:
+        with patch('celery.worker.consumer.gossip.debug') as debug:
             g.on_node_join(c)
             g.on_node_join(c)
             debug.assert_called_with('%s joined the party', 'foo@x.com')
             debug.assert_called_with('%s joined the party', 'foo@x.com')
 
 
@@ -490,7 +484,7 @@ class test_Gossip(AppCase):
         c = self.Consumer()
         c = self.Consumer()
         c.app.connection_for_read = _amqp_connection()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
         g = Gossip(c)
-        with patch('celery.worker.consumer.debug') as debug:
+        with patch('celery.worker.consumer.gossip.debug') as debug:
             g.on_node_leave(c)
             g.on_node_leave(c)
             debug.assert_called_with('%s left', 'foo@x.com')
             debug.assert_called_with('%s left', 'foo@x.com')
 
 
@@ -498,7 +492,7 @@ class test_Gossip(AppCase):
         c = self.Consumer()
         c = self.Consumer()
         c.app.connection_for_read = _amqp_connection()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
         g = Gossip(c)
-        with patch('celery.worker.consumer.info') as info:
+        with patch('celery.worker.consumer.gossip.info') as info:
             g.on_node_lost(c)
             g.on_node_lost(c)
             info.assert_called_with('missed heartbeat from %s', 'foo@x.com')
             info.assert_called_with('missed heartbeat from %s', 'foo@x.com')
 
 

+ 5 - 5
celery/tests/worker/test_worker.py

@@ -214,7 +214,7 @@ class test_Consumer(AppCase):
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
         self.assertTrue(heart.closed)
 
 
-    @patch('celery.worker.consumer.warn')
+    @patch('celery.worker.consumer.consumer.warn')
     def test_receive_message_unknown(self, warn):
     def test_receive_message_unknown(self, warn):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.blueprint.state = RUN
@@ -250,7 +250,7 @@ class test_Consumer(AppCase):
         callback(m)
         callback(m)
         self.assertTrue(m.acknowledged)
         self.assertTrue(m.acknowledged)
 
 
-    @patch('celery.worker.consumer.error')
+    @patch('celery.worker.consumer.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
     def test_receive_message_InvalidTaskError(self, error):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.blueprint.state = RUN
@@ -271,7 +271,7 @@ class test_Consumer(AppCase):
         self.assertTrue(error.called)
         self.assertTrue(error.called)
         self.assertIn('Received invalid task message', error.call_args[0][0])
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
 
-    @patch('celery.worker.consumer.crit')
+    @patch('celery.worker.consumer.consumer.crit')
     def test_on_decode_error(self, crit):
     def test_on_decode_error(self, crit):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
 
 
@@ -531,8 +531,8 @@ class test_Consumer(AppCase):
             self.buffer.get_nowait()
             self.buffer.get_nowait()
         self.assertTrue(self.timer.empty())
         self.assertTrue(self.timer.empty())
 
 
-    @patch('celery.worker.consumer.warn')
-    @patch('celery.worker.consumer.logger')
+    @patch('celery.worker.consumer.consumer.warn')
+    @patch('celery.worker.consumer.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
     def test_receieve_message_ack_raises(self, logger, warn):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.controller = l.app.WorkController()
         l.controller = l.app.WorkController()

+ 17 - 0
celery/worker/consumer/__init__.py

@@ -0,0 +1,17 @@
+from __future__ import absolute_import, unicode_literals
+
+from .consumer import Consumer
+
+from .agent import Agent
+from .connection import Connection
+from .control import Control
+from .events import Events
+from .gossip import Gossip
+from .heart import Heart
+from .mingle import Mingle
+from .tasks import Tasks
+
+__all__ = [
+    'Consumer', 'Agent', 'Connection', 'Control',
+    'Events', 'Gossip', 'Heart', 'Mingle', 'Tasks',
+]

+ 20 - 0
celery/worker/consumer/agent.py

@@ -0,0 +1,20 @@
+from __future__ import absolute_import, unicode_literals
+
+from celery import bootsteps
+
+from .connection import Connection
+
+__all__ = ['Agent']
+
+
+class Agent(bootsteps.StartStopStep):
+
+    conditional = True
+    requires = (Connection,)
+
+    def __init__(self, c, **kwargs):
+        self.agent_cls = self.enabled = c.app.conf.worker_agent
+
+    def create(self, c):
+        agent = c.agent = self.instantiate(self.agent_cls, c.connection)
+        return agent

+ 33 - 0
celery/worker/consumer/connection.py

@@ -0,0 +1,33 @@
+from __future__ import absolute_import, unicode_literals
+
+from kombu.common import ignore_errors
+
+from celery import bootsteps
+from celery.utils.log import get_logger
+
+__all__ = ['Connection']
+logger = get_logger(__name__)
+info = logger.info
+
+
+class Connection(bootsteps.StartStopStep):
+
+    def __init__(self, c, **kwargs):
+        c.connection = None
+
+    def start(self, c):
+        c.connection = c.connect()
+        info('Connected to %s', c.connection.as_uri())
+
+    def shutdown(self, c):
+        # We must set self.connection to None here, so
+        # that the green pidbox thread exits.
+        connection, c.connection = c.connection, None
+        if connection:
+            ignore_errors(connection, connection.close)
+
+    def info(self, c, params='N/A'):
+        if c.connection:
+            params = c.connection.info()
+            params.pop('password', None)  # don't send password.
+        return {'broker': params}

+ 17 - 395
celery/worker/consumer.py → celery/worker/consumer/consumer.py

@@ -11,22 +11,17 @@ up and running.
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 import errno
 import errno
-import kombu
 import logging
 import logging
 import os
 import os
 
 
 from collections import defaultdict
 from collections import defaultdict
-from functools import partial
-from heapq import heappush
-from operator import itemgetter
 from time import sleep
 from time import sleep
 
 
 from amqp.promise import ppartial, promise
 from amqp.promise import ppartial, promise
 from billiard.common import restart_state
 from billiard.common import restart_state
 from billiard.exceptions import RestartFreqExceeded
 from billiard.exceptions import RestartFreqExceeded
 from kombu.async.semaphore import DummyLock
 from kombu.async.semaphore import DummyLock
-from kombu.common import QoS, ignore_errors
-from kombu.five import buffer_t, items, values
+from kombu.five import buffer_t, items
 from kombu.syn import _detect_environment
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr, bytes_t
 from kombu.utils.encoding import safe_repr, bytes_t
 from kombu.utils.limits import TokenBucket
 from kombu.utils.limits import TokenBucket
@@ -34,22 +29,19 @@ from kombu.utils.limits import TokenBucket
 from celery import bootsteps
 from celery import bootsteps
 from celery import signals
 from celery import signals
 from celery.app.trace import build_tracer
 from celery.app.trace import build_tracer
-from celery.canvas import signature
 from celery.exceptions import InvalidTaskError, NotRegistered
 from celery.exceptions import InvalidTaskError, NotRegistered
 from celery.utils import gethostname
 from celery.utils import gethostname
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
-from celery.utils.objects import Bunch
 from celery.utils.text import truncate
 from celery.utils.text import truncate
 from celery.utils.timeutils import humanize_seconds, rate
 from celery.utils.timeutils import humanize_seconds, rate
 
 
-from . import heartbeat, loops, pidbox
-from .state import task_reserved, maybe_shutdown, revoked, reserved_requests
+from celery.worker import loops
+from celery.worker.state import (
+    task_reserved, maybe_shutdown, reserved_requests,
+)
 
 
-__all__ = [
-    'Consumer', 'Connection', 'Events', 'Heart', 'Control',
-    'Tasks', 'Evloop', 'Agent', 'Mingle', 'Gossip', 'dump_body',
-]
+__all__ = ['Consumer', 'Evloop', 'dump_body']
 
 
 CLOSE = bootsteps.CLOSE
 CLOSE = bootsteps.CLOSE
 logger = get_logger(__name__)
 logger = get_logger(__name__)
@@ -117,8 +109,6 @@ body: {0}
   delivery_info:{3} headers={4}}}
   delivery_info:{3} headers={4}}}
 """
 """
 
 
-MINGLE_GET_FIELDS = itemgetter('clock', 'revoked')
-
 
 
 def dump_body(m, body):
 def dump_body(m, body):
     # v2 protocol does not deserialize body
     # v2 protocol does not deserialize body
@@ -130,6 +120,7 @@ def dump_body(m, body):
 
 
 
 
 class Consumer(object):
 class Consumer(object):
+
     Strategies = dict
     Strategies = dict
 
 
     #: set when consumer is shutting down.
     #: set when consumer is shutting down.
@@ -151,15 +142,15 @@ class Consumer(object):
     class Blueprint(bootsteps.Blueprint):
     class Blueprint(bootsteps.Blueprint):
         name = 'Consumer'
         name = 'Consumer'
         default_steps = [
         default_steps = [
-            'celery.worker.consumer:Connection',
-            'celery.worker.consumer:Mingle',
-            'celery.worker.consumer:Events',
-            'celery.worker.consumer:Gossip',
-            'celery.worker.consumer:Heart',
-            'celery.worker.consumer:Control',
-            'celery.worker.consumer:Tasks',
-            'celery.worker.consumer:Evloop',
-            'celery.worker.consumer:Agent',
+            'celery.worker.consumer.connection:Connection',
+            'celery.worker.consumer.mingle:Mingle',
+            'celery.worker.consumer.events:Events',
+            'celery.worker.consumer.gossip:Gossip',
+            'celery.worker.consumer.heart:Heart',
+            'celery.worker.consumer.control:Control',
+            'celery.worker.consumer.tasks:Tasks',
+            'celery.worker.consumer.consumer:Evloop',
+            'celery.worker.consumer.agent:Agent',
         ]
         ]
 
 
         def shutdown(self, parent):
         def shutdown(self, parent):
@@ -538,377 +529,8 @@ class Consumer(object):
         )
         )
 
 
 
 
-class Connection(bootsteps.StartStopStep):
-
-    def __init__(self, c, **kwargs):
-        c.connection = None
-
-    def start(self, c):
-        c.connection = c.connect()
-        info('Connected to %s', c.connection.as_uri())
-
-    def shutdown(self, c):
-        # We must set self.connection to None here, so
-        # that the green pidbox thread exits.
-        connection, c.connection = c.connection, None
-        if connection:
-            ignore_errors(connection, connection.close)
-
-    def info(self, c, params='N/A'):
-        if c.connection:
-            params = c.connection.info()
-            params.pop('password', None)  # don't send password.
-        return {'broker': params}
-
-
-class Events(bootsteps.StartStopStep):
-    requires = (Connection,)
-
-    def __init__(self, c, send_events=True,
-                 without_heartbeat=False, without_gossip=False, **kwargs):
-        self.groups = None if send_events else ['worker']
-        self.send_events = (
-            send_events or
-            not without_gossip or
-            not without_heartbeat
-        )
-        c.event_dispatcher = None
-
-    def start(self, c):
-        # flush events sent while connection was down.
-        prev = self._close(c)
-        dis = c.event_dispatcher = c.app.events.Dispatcher(
-            c.connect(), hostname=c.hostname,
-            enabled=self.send_events, groups=self.groups,
-            buffer_group=['task'] if c.hub else None,
-            on_send_buffered=c.on_send_event_buffered if c.hub else None,
-        )
-        if prev:
-            dis.extend_buffer(prev)
-            dis.flush()
-
-    def stop(self, c):
-        pass
-
-    def _close(self, c):
-        if c.event_dispatcher:
-            dispatcher = c.event_dispatcher
-            # remember changes from remote control commands:
-            self.groups = dispatcher.groups
-
-            # close custom connection
-            if dispatcher.connection:
-                ignore_errors(c, dispatcher.connection.close)
-            ignore_errors(c, dispatcher.close)
-            c.event_dispatcher = None
-            return dispatcher
-
-    def shutdown(self, c):
-        self._close(c)
-
-
-class Heart(bootsteps.StartStopStep):
-    requires = (Events,)
-
-    def __init__(self, c, without_heartbeat=False, heartbeat_interval=None,
-                 **kwargs):
-        self.enabled = not without_heartbeat
-        self.heartbeat_interval = heartbeat_interval
-        c.heart = None
-
-    def start(self, c):
-        c.heart = heartbeat.Heart(
-            c.timer, c.event_dispatcher, self.heartbeat_interval,
-        )
-        c.heart.start()
-
-    def stop(self, c):
-        c.heart = c.heart and c.heart.stop()
-    shutdown = stop
-
-
-class Mingle(bootsteps.StartStopStep):
-    label = 'Mingle'
-    requires = (Events,)
-    compatible_transports = {'amqp', 'redis'}
-
-    def __init__(self, c, without_mingle=False, **kwargs):
-        self.enabled = not without_mingle and self.compatible_transport(c.app)
-
-    def compatible_transport(self, app):
-        with app.connection_for_read() as conn:
-            return conn.transport.driver_type in self.compatible_transports
-
-    def start(self, c):
-        info('mingle: searching for neighbors')
-        I = c.app.control.inspect(timeout=1.0, connection=c.connection)
-        replies = I.hello(c.hostname, revoked._data) or {}
-        replies.pop(c.hostname, None)
-        if replies:
-            info('mingle: sync with %s nodes',
-                 len([reply for reply, value in items(replies) if value]))
-            for reply in values(replies):
-                if reply:
-                    try:
-                        other_clock, other_revoked = MINGLE_GET_FIELDS(reply)
-                    except KeyError:  # reply from pre-3.1 worker
-                        pass
-                    else:
-                        c.app.clock.adjust(other_clock)
-                        revoked.update(other_revoked)
-            info('mingle: sync complete')
-        else:
-            info('mingle: all alone')
-
-
-class Tasks(bootsteps.StartStopStep):
-    requires = (Mingle,)
-
-    def __init__(self, c, **kwargs):
-        c.task_consumer = c.qos = None
-
-    def start(self, c):
-        c.update_strategies()
-
-        # - RabbitMQ 3.3 completely redefines how basic_qos works..
-        # This will detect if the new qos smenatics is in effect,
-        # and if so make sure the 'apply_global' flag is set on qos updates.
-        qos_global = not c.connection.qos_semantics_matches_spec
-
-        # set initial prefetch count
-        c.connection.default_channel.basic_qos(
-            0, c.initial_prefetch_count, qos_global,
-        )
-
-        c.task_consumer = c.app.amqp.TaskConsumer(
-            c.connection, on_decode_error=c.on_decode_error,
-        )
-
-        def set_prefetch_count(prefetch_count):
-            return c.task_consumer.qos(
-                prefetch_count=prefetch_count,
-                apply_global=qos_global,
-            )
-        c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
-
-    def stop(self, c):
-        if c.task_consumer:
-            debug('Cancelling task consumer...')
-            ignore_errors(c, c.task_consumer.cancel)
-
-    def shutdown(self, c):
-        if c.task_consumer:
-            self.stop(c)
-            debug('Closing consumer channel...')
-            ignore_errors(c, c.task_consumer.close)
-            c.task_consumer = None
-
-    def info(self, c):
-        return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
-
-
-class Agent(bootsteps.StartStopStep):
-    conditional = True
-    requires = (Connection,)
-
-    def __init__(self, c, **kwargs):
-        self.agent_cls = self.enabled = c.app.conf.worker_agent
-
-    def create(self, c):
-        agent = c.agent = self.instantiate(self.agent_cls, c.connection)
-        return agent
-
-
-class Control(bootsteps.StartStopStep):
-    requires = (Tasks,)
-
-    def __init__(self, c, **kwargs):
-        self.is_green = c.pool is not None and c.pool.is_green
-        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
-        self.start = self.box.start
-        self.stop = self.box.stop
-        self.shutdown = self.box.shutdown
-
-    def include_if(self, c):
-        return (c.app.conf.worker_enable_remote_control and
-                c.conninfo.supports_exchange_type('fanout'))
-
-
-class Gossip(bootsteps.ConsumerStep):
-    label = 'Gossip'
-    requires = (Mingle,)
-    _cons_stamp_fields = itemgetter(
-        'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
-    )
-    compatible_transports = {'amqp', 'redis'}
-
-    def __init__(self, c, without_gossip=False,
-                 interval=5.0, heartbeat_interval=2.0, **kwargs):
-        self.enabled = not without_gossip and self.compatible_transport(c.app)
-        self.app = c.app
-        c.gossip = self
-        self.Receiver = c.app.events.Receiver
-        self.hostname = c.hostname
-        self.full_hostname = '.'.join([self.hostname, str(c.pid)])
-        self.on = Bunch(
-            node_join=set(),
-            node_leave=set(),
-            node_lost=set(),
-        )
-
-        self.timer = c.timer
-        if self.enabled:
-            self.state = c.app.events.State(
-                on_node_join=self.on_node_join,
-                on_node_leave=self.on_node_leave,
-                max_tasks_in_memory=1,
-            )
-            if c.hub:
-                c._mutex = DummyLock()
-            self.update_state = self.state.event
-        self.interval = interval
-        self.heartbeat_interval = heartbeat_interval
-        self._tref = None
-        self.consensus_requests = defaultdict(list)
-        self.consensus_replies = {}
-        self.event_handlers = {
-            'worker.elect': self.on_elect,
-            'worker.elect.ack': self.on_elect_ack,
-        }
-        self.clock = c.app.clock
-
-        self.election_handlers = {
-            'task': self.call_task
-        }
-
-    def compatible_transport(self, app):
-        with app.connection_for_read() as conn:
-            return conn.transport.driver_type in self.compatible_transports
-
-    def election(self, id, topic, action=None):
-        self.consensus_replies[id] = []
-        self.dispatcher.send(
-            'worker-elect',
-            id=id, topic=topic, action=action, cver=1,
-        )
-
-    def call_task(self, task):
-        try:
-            signature(task, app=self.app).apply_async()
-        except Exception as exc:
-            error('Could not call task: %r', exc, exc_info=1)
-
-    def on_elect(self, event):
-        try:
-            (id_, clock, hostname, pid,
-             topic, action, _) = self._cons_stamp_fields(event)
-        except KeyError as exc:
-            return error('election request missing field %s', exc, exc_info=1)
-        heappush(
-            self.consensus_requests[id_],
-            (clock, '%s.%s' % (hostname, pid), topic, action),
-        )
-        self.dispatcher.send('worker-elect-ack', id=id_)
-
-    def start(self, c):
-        super(Gossip, self).start(c)
-        self.dispatcher = c.event_dispatcher
-
-    def on_elect_ack(self, event):
-        id = event['id']
-        try:
-            replies = self.consensus_replies[id]
-        except KeyError:
-            return  # not for us
-        alive_workers = self.state.alive_workers()
-        replies.append(event['hostname'])
-
-        if len(replies) >= len(alive_workers):
-            _, leader, topic, action = self.clock.sort_heap(
-                self.consensus_requests[id],
-            )
-            if leader == self.full_hostname:
-                info('I won the election %r', id)
-                try:
-                    handler = self.election_handlers[topic]
-                except KeyError:
-                    error('Unknown election topic %r', topic, exc_info=1)
-                else:
-                    handler(action)
-            else:
-                info('node %s elected for %r', leader, id)
-            self.consensus_requests.pop(id, None)
-            self.consensus_replies.pop(id, None)
-
-    def on_node_join(self, worker):
-        debug('%s joined the party', worker.hostname)
-        self._call_handlers(self.on.node_join, worker)
-
-    def on_node_leave(self, worker):
-        debug('%s left', worker.hostname)
-        self._call_handlers(self.on.node_leave, worker)
-
-    def on_node_lost(self, worker):
-        info('missed heartbeat from %s', worker.hostname)
-        self._call_handlers(self.on.node_lost, worker)
-
-    def _call_handlers(self, handlers, *args, **kwargs):
-        for handler in handlers:
-            try:
-                handler(*args, **kwargs)
-            except Exception as exc:
-                error('Ignored error from handler %r: %r',
-                      handler, exc, exc_info=1)
-
-    def register_timer(self):
-        if self._tref is not None:
-            self._tref.cancel()
-        self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
-
-    def periodic(self):
-        workers = self.state.workers
-        dirty = set()
-        for worker in values(workers):
-            if not worker.alive:
-                dirty.add(worker)
-                self.on_node_lost(worker)
-        for worker in dirty:
-            workers.pop(worker.hostname, None)
-
-    def get_consumers(self, channel):
-        self.register_timer()
-        ev = self.Receiver(channel, routing_key='worker.#',
-                           queue_ttl=self.heartbeat_interval)
-        return [kombu.Consumer(
-            channel,
-            queues=[ev.queue],
-            on_message=partial(self.on_message, ev.event_from_message),
-            no_ack=True
-        )]
-
-    def on_message(self, prepare, message):
-        _type = message.delivery_info['routing_key']
-
-        # For redis when `fanout_patterns=False` (See Issue #1882)
-        if _type.split('.', 1)[0] == 'task':
-            return
-        try:
-            handler = self.event_handlers[_type]
-        except KeyError:
-            pass
-        else:
-            return handler(message.payload)
-
-        hostname = (message.headers.get('hostname') or
-                    message.payload['hostname'])
-        if hostname != self.hostname:
-            type, event = prepare(message.payload)
-            self.update_state(event)
-        else:
-            self.clock.forward()
-
-
 class Evloop(bootsteps.StartStopStep):
 class Evloop(bootsteps.StartStopStep):
+
     label = 'event loop'
     label = 'event loop'
     last = True
     last = True
 
 

+ 27 - 0
celery/worker/consumer/control.py

@@ -0,0 +1,27 @@
+from __future__ import absolute_import, unicode_literals
+
+from celery import bootsteps
+from celery.utils.log import get_logger
+
+from celery.worker import pidbox
+
+from .tasks import Tasks
+
+__all__ = ['Control']
+logger = get_logger(__name__)
+
+
+class Control(bootsteps.StartStopStep):
+
+    requires = (Tasks,)
+
+    def __init__(self, c, **kwargs):
+        self.is_green = c.pool is not None and c.pool.is_green
+        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
+        self.start = self.box.start
+        self.stop = self.box.stop
+        self.shutdown = self.box.shutdown
+
+    def include_if(self, c):
+        return (c.app.conf.worker_enable_remote_control and
+                c.conninfo.supports_exchange_type('fanout'))

+ 56 - 0
celery/worker/consumer/events.py

@@ -0,0 +1,56 @@
+from __future__ import absolute_import, unicode_literals
+
+from kombu.common import ignore_errors
+
+from celery import bootsteps
+
+from .connection import Connection
+
+__all__ = ['Events']
+
+
+class Events(bootsteps.StartStopStep):
+
+    requires = (Connection,)
+
+    def __init__(self, c, send_events=True,
+                 without_heartbeat=False, without_gossip=False, **kwargs):
+        self.groups = None if send_events else ['worker']
+        self.send_events = (
+            send_events or
+            not without_gossip or
+            not without_heartbeat
+        )
+        c.event_dispatcher = None
+
+    def start(self, c):
+        # flush events sent while connection was down.
+        prev = self._close(c)
+        dis = c.event_dispatcher = c.app.events.Dispatcher(
+            c.connect(), hostname=c.hostname,
+            enabled=self.send_events, groups=self.groups,
+            buffer_group=['task'] if c.hub else None,
+            on_send_buffered=c.on_send_event_buffered if c.hub else None,
+        )
+        if prev:
+            dis.extend_buffer(prev)
+            dis.flush()
+
+    def stop(self, c):
+        pass
+
+    def _close(self, c):
+        if c.event_dispatcher:
+            dispatcher = c.event_dispatcher
+            # remember changes from remote control commands:
+            self.groups = dispatcher.groups
+
+            # close custom connection
+            if dispatcher.connection:
+                ignore_errors(c, dispatcher.connection.close)
+            ignore_errors(c, dispatcher.close)
+            c.event_dispatcher = None
+            return dispatcher
+
+    def shutdown(self, c):
+        self._close(c)

+ 195 - 0
celery/worker/consumer/gossip.py

@@ -0,0 +1,195 @@
+from __future__ import absolute_import, unicode_literals
+
+from collections import defaultdict
+from functools import partial
+from heapq import heappush
+from operator import itemgetter
+
+from kombu import Consumer
+from kombu.async.semaphore import DummyLock
+
+from celery import bootsteps
+from celery.five import values
+from celery.utils.log import get_logger
+from celery.utils.objects import Bunch
+
+from .mingle import Mingle
+
+__all__ = ['Gossip']
+logger = get_logger(__name__)
+debug, info, error = logger.debug, logger.info, logger.error
+
+
+class Gossip(bootsteps.ConsumerStep):
+
+    label = 'Gossip'
+    requires = (Mingle,)
+    _cons_stamp_fields = itemgetter(
+        'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
+    )
+    compatible_transports = {'amqp', 'redis'}
+
+    def __init__(self, c, without_gossip=False,
+                 interval=5.0, heartbeat_interval=2.0, **kwargs):
+        self.enabled = not without_gossip and self.compatible_transport(c.app)
+        self.app = c.app
+        c.gossip = self
+        self.Receiver = c.app.events.Receiver
+        self.hostname = c.hostname
+        self.full_hostname = '.'.join([self.hostname, str(c.pid)])
+        self.on = Bunch(
+            node_join=set(),
+            node_leave=set(),
+            node_lost=set(),
+        )
+
+        self.timer = c.timer
+        if self.enabled:
+            self.state = c.app.events.State(
+                on_node_join=self.on_node_join,
+                on_node_leave=self.on_node_leave,
+                max_tasks_in_memory=1,
+            )
+            if c.hub:
+                c._mutex = DummyLock()
+            self.update_state = self.state.event
+        self.interval = interval
+        self.heartbeat_interval = heartbeat_interval
+        self._tref = None
+        self.consensus_requests = defaultdict(list)
+        self.consensus_replies = {}
+        self.event_handlers = {
+            'worker.elect': self.on_elect,
+            'worker.elect.ack': self.on_elect_ack,
+        }
+        self.clock = c.app.clock
+
+        self.election_handlers = {
+            'task': self.call_task
+        }
+
+    def compatible_transport(self, app):
+        with app.connection_for_read() as conn:
+            return conn.transport.driver_type in self.compatible_transports
+
+    def election(self, id, topic, action=None):
+        self.consensus_replies[id] = []
+        self.dispatcher.send(
+            'worker-elect',
+            id=id, topic=topic, action=action, cver=1,
+        )
+
+    def call_task(self, task):
+        try:
+            self.app.signature(task).apply_async()
+        except Exception as exc:
+            error('Could not call task: %r', exc, exc_info=1)
+
+    def on_elect(self, event):
+        try:
+            (id_, clock, hostname, pid,
+             topic, action, _) = self._cons_stamp_fields(event)
+        except KeyError as exc:
+            return error('election request missing field %s', exc, exc_info=1)
+        heappush(
+            self.consensus_requests[id_],
+            (clock, '%s.%s' % (hostname, pid), topic, action),
+        )
+        self.dispatcher.send('worker-elect-ack', id=id_)
+
+    def start(self, c):
+        super(Gossip, self).start(c)
+        self.dispatcher = c.event_dispatcher
+
+    def on_elect_ack(self, event):
+        id = event['id']
+        try:
+            replies = self.consensus_replies[id]
+        except KeyError:
+            return  # not for us
+        alive_workers = self.state.alive_workers()
+        replies.append(event['hostname'])
+
+        if len(replies) >= len(alive_workers):
+            _, leader, topic, action = self.clock.sort_heap(
+                self.consensus_requests[id],
+            )
+            if leader == self.full_hostname:
+                info('I won the election %r', id)
+                try:
+                    handler = self.election_handlers[topic]
+                except KeyError:
+                    error('Unknown election topic %r', topic, exc_info=1)
+                else:
+                    handler(action)
+            else:
+                info('node %s elected for %r', leader, id)
+            self.consensus_requests.pop(id, None)
+            self.consensus_replies.pop(id, None)
+
+    def on_node_join(self, worker):
+        debug('%s joined the party', worker.hostname)
+        self._call_handlers(self.on.node_join, worker)
+
+    def on_node_leave(self, worker):
+        debug('%s left', worker.hostname)
+        self._call_handlers(self.on.node_leave, worker)
+
+    def on_node_lost(self, worker):
+        info('missed heartbeat from %s', worker.hostname)
+        self._call_handlers(self.on.node_lost, worker)
+
+    def _call_handlers(self, handlers, *args, **kwargs):
+        for handler in handlers:
+            try:
+                handler(*args, **kwargs)
+            except Exception as exc:
+                error('Ignored error from handler %r: %r',
+                      handler, exc, exc_info=1)
+
+    def register_timer(self):
+        if self._tref is not None:
+            self._tref.cancel()
+        self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
+
+    def periodic(self):
+        workers = self.state.workers
+        dirty = set()
+        for worker in values(workers):
+            if not worker.alive:
+                dirty.add(worker)
+                self.on_node_lost(worker)
+        for worker in dirty:
+            workers.pop(worker.hostname, None)
+
+    def get_consumers(self, channel):
+        self.register_timer()
+        ev = self.Receiver(channel, routing_key='worker.#',
+                           queue_ttl=self.heartbeat_interval)
+        return [Consumer(
+            channel,
+            queues=[ev.queue],
+            on_message=partial(self.on_message, ev.event_from_message),
+            no_ack=True
+        )]
+
+    def on_message(self, prepare, message):
+        _type = message.delivery_info['routing_key']
+
+        # For redis when `fanout_patterns=False` (See Issue #1882)
+        if _type.split('.', 1)[0] == 'task':
+            return
+        try:
+            handler = self.event_handlers[_type]
+        except KeyError:
+            pass
+        else:
+            return handler(message.payload)
+
+        hostname = (message.headers.get('hostname') or
+                    message.payload['hostname'])
+        if hostname != self.hostname:
+            type, event = prepare(message.payload)
+            self.update_state(event)
+        else:
+            self.clock.forward()

+ 30 - 0
celery/worker/consumer/heart.py

@@ -0,0 +1,30 @@
+from __future__ import absolute_import, unicode_literals
+
+from celery import bootsteps
+
+from celery.worker import heartbeat
+
+from .events import Events
+
+__all__ = ['Heart']
+
+
+class Heart(bootsteps.StartStopStep):
+
+    requires = (Events,)
+
+    def __init__(self, c,
+                 without_heartbeat=False, heartbeat_interval=None, **kwargs):
+        self.enabled = not without_heartbeat
+        self.heartbeat_interval = heartbeat_interval
+        c.heart = None
+
+    def start(self, c):
+        c.heart = heartbeat.Heart(
+            c.timer, c.event_dispatcher, self.heartbeat_interval,
+        )
+        c.heart.start()
+
+    def stop(self, c):
+        c.heart = c.heart and c.heart.stop()
+    shutdown = stop

+ 53 - 0
celery/worker/consumer/mingle.py

@@ -0,0 +1,53 @@
+from __future__ import absolute_import, unicode_literals
+
+from operator import itemgetter
+
+from celery import bootsteps
+from celery.five import items, values
+from celery.utils.log import get_logger
+
+from celery.worker.state import revoked
+
+from .events import Events
+
+__all__ = ['Mingle']
+
+MINGLE_GET_FIELDS = itemgetter('clock', 'revoked')
+
+logger = get_logger(__name__)
+info = logger.info
+
+
+class Mingle(bootsteps.StartStopStep):
+
+    label = 'Mingle'
+    requires = (Events,)
+    compatible_transports = {'amqp', 'redis'}
+
+    def __init__(self, c, without_mingle=False, **kwargs):
+        self.enabled = not without_mingle and self.compatible_transport(c.app)
+
+    def compatible_transport(self, app):
+        with app.connection_for_read() as conn:
+            return conn.transport.driver_type in self.compatible_transports
+
+    def start(self, c):
+        info('mingle: searching for neighbors')
+        I = c.app.control.inspect(timeout=1.0, connection=c.connection)
+        replies = I.hello(c.hostname, revoked._data) or {}
+        replies.pop(c.hostname, None)
+        if replies:
+            info('mingle: sync with %s nodes',
+                 len([reply for reply, value in items(replies) if value]))
+            for reply in values(replies):
+                if reply:
+                    try:
+                        other_clock, other_revoked = MINGLE_GET_FIELDS(reply)
+                    except KeyError:  # reply from pre-3.1 worker
+                        pass
+                    else:
+                        c.app.clock.adjust(other_clock)
+                        revoked.update(other_revoked)
+            info('mingle: sync complete')
+        else:
+            info('mingle: all alone')

+ 59 - 0
celery/worker/consumer/tasks.py

@@ -0,0 +1,59 @@
+from __future__ import absolute_import, unicode_literals
+
+from kombu.common import QoS, ignore_errors
+
+from celery import bootsteps
+from celery.utils.log import get_logger
+
+from .mingle import Mingle
+
+__all__ = ['Tasks']
+logger = get_logger(__name__)
+debug = logger.debug
+
+
+class Tasks(bootsteps.StartStopStep):
+
+    requires = (Mingle,)
+
+    def __init__(self, c, **kwargs):
+        c.task_consumer = c.qos = None
+
+    def start(self, c):
+        c.update_strategies()
+
+        # - RabbitMQ 3.3 completely redefines how basic_qos works..
+        # This will detect if the new qos smenatics is in effect,
+        # and if so make sure the 'apply_global' flag is set on qos updates.
+        qos_global = not c.connection.qos_semantics_matches_spec
+
+        # set initial prefetch count
+        c.connection.default_channel.basic_qos(
+            0, c.initial_prefetch_count, qos_global,
+        )
+
+        c.task_consumer = c.app.amqp.TaskConsumer(
+            c.connection, on_decode_error=c.on_decode_error,
+        )
+
+        def set_prefetch_count(prefetch_count):
+            return c.task_consumer.qos(
+                prefetch_count=prefetch_count,
+                apply_global=qos_global,
+            )
+        c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
+
+    def stop(self, c):
+        if c.task_consumer:
+            debug('Cancelling task consumer...')
+            ignore_errors(c, c.task_consumer.cancel)
+
+    def shutdown(self, c):
+        if c.task_consumer:
+            self.stop(c)
+            debug('Closing consumer channel...')
+            ignore_errors(c, c.task_consumer.close)
+            c.task_consumer = None
+
+    def info(self, c):
+        return {'prefetch_count': c.qos.value if c.qos else 'N/A'}