
[Worker] Move each consumer bootstep into a separate module.

Ask Solem · 9 years ago
parent commit c31b5cf54d
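In effect, the monolithic celery/worker/consumer.py becomes a celery/worker/consumer/ package with one module per bootstep. A sketch of what that means for import paths, taken from the diffs below (the package __init__ re-exports every step, so the old top-level imports keep working):

    # before: all bootsteps lived in a single module
    from celery.worker.consumer import Consumer, Heart, Gossip

    # after: each bootstep gets its own module under the package ...
    from celery.worker.consumer.consumer import Consumer
    from celery.worker.consumer.heart import Heart
    from celery.worker.consumer.gossip import Gossip

    # ... while the package __init__ keeps the old spelling valid:
    from celery.worker.consumer import Consumer, Heart, Gossip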

+ 32 - 38
celery/tests/worker/test_consumer.py

@@ -7,16 +7,12 @@ from billiard.exceptions import RestartFreqExceeded
 
 from celery.datastructures import LimitedSet
 from celery.worker import state as worker_state
-from celery.worker.consumer import (
-    Consumer,
-    Heart,
-    Tasks,
-    Agent,
-    Mingle,
-    Gossip,
-    dump_body,
-    CLOSE,
-)
+from celery.worker.consumer.agent import Agent
+from celery.worker.consumer.consumer import CLOSE, Consumer, dump_body
+from celery.worker.consumer.gossip import Gossip
+from celery.worker.consumer.heart import Heart
+from celery.worker.consumer.mingle import Mingle
+from celery.worker.consumer.tasks import Tasks
 
 from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch
 
@@ -65,19 +61,19 @@ class test_Consumer(AppCase):
         self.assertEqual(c.amqheartbeat, 20)
 
     def test_gevent_bug_disables_connection_timeout(self):
-        with patch('celery.worker.consumer._detect_environment') as de:
-            de.return_value = 'gevent'
+        with patch('celery.worker.consumer.consumer._detect_environment') as d:
+            d.return_value = 'gevent'
             self.app.conf.broker_connection_timeout = 33.33
             self.get_consumer()
             self.assertIsNone(self.app.conf.broker_connection_timeout)
 
     def test_limit_moved_to_pool(self):
-        with patch('celery.worker.consumer.task_reserved') as reserved:
+        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
             c = self.get_consumer()
             c.on_task_request = Mock(name='on_task_request')
             request = Mock(name='request')
             c._limit_move_to_pool(request)
-            reserved.assert_called_with(request)
+            reserv.assert_called_with(request)
             c.on_task_request.assert_called_with(request)
 
     def test_update_prefetch_count(self):
@@ -112,17 +108,17 @@ class test_Consumer(AppCase):
     def test_limit_task(self):
         c = self.get_consumer()
 
-        with patch('celery.worker.consumer.task_reserved') as reserved:
+        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
             bucket = Mock()
             request = Mock()
             bucket.can_consume.return_value = True
 
             c._limit_task(request, bucket, 3)
             bucket.can_consume.assert_called_with(3)
-            reserved.assert_called_with(request)
+            reserv.assert_called_with(request)
             c.on_task_request.assert_called_with(request)
 
-        with patch('celery.worker.consumer.task_reserved') as reserved:
+        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
             bucket.can_consume.return_value = False
             bucket.expected_time.return_value = 3.33
             limit_order = c._limit_order
@@ -134,7 +130,7 @@ class test_Consumer(AppCase):
                 priority=c._limit_order,
             )
             bucket.expected_time.assert_called_with(4)
-            self.assertFalse(reserved.called)
+            self.assertFalse(reserv.called)
 
     def test_start_blueprint_raises_EMFILE(self):
         c = self.get_consumer()
@@ -153,7 +149,7 @@ class test_Consumer(AppCase):
         c._restart_state.step.side_effect = se
         c.blueprint.start.side_effect = socket.error()
 
-        with patch('celery.worker.consumer.sleep') as sleep:
+        with patch('celery.worker.consumer.consumer.sleep') as sleep:
             c.start()
             sleep.assert_called_with(1)
 
@@ -182,12 +178,12 @@ class test_Consumer(AppCase):
         c.register_with_event_loop(Mock(name='loop'))
 
     def test_on_close_clears_semaphore_timer_and_reqs(self):
-        with patch('celery.worker.consumer.reserved_requests') as reserved:
+        with patch('celery.worker.consumer.consumer.reserved_requests') as reserv:
             c = self.get_consumer()
             c.on_close()
             c.controller.semaphore.clear.assert_called_with()
             c.timer.clear.assert_called_with()
-            reserved.clear.assert_called_with()
+            reserv.clear.assert_called_with()
             c.pool.flush.assert_called_with()
 
             c.controller = None
@@ -375,18 +371,16 @@ class test_Gossip(AppCase):
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
         g.start(c)
-
-        with patch('celery.worker.consumer.signature') as signature:
-            sig = signature.return_value = Mock()
-            task = Mock()
+        signature = g.app.signature = Mock(name='app.signature')
+        task = Mock()
+        g.call_task(task)
+        signature.assert_called_with(task)
+        signature.return_value.apply_async.assert_called_with()
+
+        signature.return_value.apply_async.side_effect = MemoryError()
+        with patch('celery.worker.consumer.gossip.error') as error:
             g.call_task(task)
-            signature.assert_called_with(task, app=c.app)
-            sig.apply_async.assert_called_with()
-
-            sig.apply_async.side_effect = MemoryError()
-            with patch('celery.worker.consumer.error') as error:
-                g.call_task(task)
-                self.assertTrue(error.called)
+            self.assertTrue(error.called)
 
     def Event(self, id='id', clock=312,
               hostname='foo@example.com', pid=4312,
@@ -414,7 +408,7 @@ class test_Gossip(AppCase):
         g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
 
         event.pop('clock')
-        with patch('celery.worker.consumer.error') as error:
+        with patch('celery.worker.consumer.gossip.error') as error:
             g.on_elect(event)
             self.assertTrue(error.called)
 
@@ -444,7 +438,7 @@ class test_Gossip(AppCase):
         g.on_elect(e3)
         self.assertEqual(len(g.consensus_requests['id1']), 3)
 
-        with patch('celery.worker.consumer.info'):
+        with patch('celery.worker.consumer.gossip.info'):
             g.on_elect_ack(e1)
             self.assertEqual(len(g.consensus_replies['id1']), 1)
             g.on_elect_ack(e2)
@@ -474,7 +468,7 @@ class test_Gossip(AppCase):
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
         g.election_handlers = {}
-        with patch('celery.worker.consumer.error') as error:
+        with patch('celery.worker.consumer.gossip.error') as error:
             self.setup_election(g, c)
             self.assertTrue(error.called)
 
@@ -482,7 +476,7 @@ class test_Gossip(AppCase):
         c = self.Consumer()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
-        with patch('celery.worker.consumer.debug') as debug:
+        with patch('celery.worker.consumer.gossip.debug') as debug:
             g.on_node_join(c)
             debug.assert_called_with('%s joined the party', 'foo@x.com')
 
@@ -490,7 +484,7 @@ class test_Gossip(AppCase):
         c = self.Consumer()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
-        with patch('celery.worker.consumer.debug') as debug:
+        with patch('celery.worker.consumer.gossip.debug') as debug:
             g.on_node_leave(c)
             debug.assert_called_with('%s left', 'foo@x.com')
 
@@ -498,7 +492,7 @@ class test_Gossip(AppCase):
         c = self.Consumer()
         c.app.connection_for_read = _amqp_connection()
         g = Gossip(c)
-        with patch('celery.worker.consumer.info') as info:
+        with patch('celery.worker.consumer.gossip.info') as info:
             g.on_node_lost(c)
             info.assert_called_with('missed heartbeat from %s', 'foo@x.com')
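The changed patch targets above follow from how mock.patch works: it must replace the name in the namespace where it is looked up, not where it is defined. After the split, helpers like task_reserved and the log shortcuts are resolved inside the new per-step modules, hence the extra path segment. A minimal illustration (reserv mirrors the tests; patch is the same celery.tests.case.patch used above):

    # task_reserved is imported into celery/worker/consumer/consumer.py,
    # so that module is where the lookup happens at call time:
    with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
        ...  # code under test now sees the mock

    # patching the defining module would leave the already-imported
    # reference untouched:
    # patch('celery.worker.state.task_reserved')  # wrong target here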
 

+ 5 - 5
celery/tests/worker/test_worker.py

@@ -214,7 +214,7 @@ class test_Consumer(AppCase):
         self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
 
-    @patch('celery.worker.consumer.warn')
+    @patch('celery.worker.consumer.consumer.warn')
     def test_receive_message_unknown(self, warn):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
@@ -250,7 +250,7 @@ class test_Consumer(AppCase):
         callback(m)
         self.assertTrue(m.acknowledged)
 
-    @patch('celery.worker.consumer.error')
+    @patch('celery.worker.consumer.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
@@ -271,7 +271,7 @@ class test_Consumer(AppCase):
         self.assertTrue(error.called)
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
-    @patch('celery.worker.consumer.crit')
+    @patch('celery.worker.consumer.consumer.crit')
     def test_on_decode_error(self, crit):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
 
@@ -531,8 +531,8 @@ class test_Consumer(AppCase):
             self.buffer.get_nowait()
         self.assertTrue(self.timer.empty())
 
-    @patch('celery.worker.consumer.warn')
-    @patch('celery.worker.consumer.logger')
+    @patch('celery.worker.consumer.consumer.warn')
+    @patch('celery.worker.consumer.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.controller = l.app.WorkController()

+ 17 - 0
celery/worker/consumer/__init__.py

@@ -0,0 +1,17 @@
+from __future__ import absolute_import, unicode_literals
+
+from .consumer import Consumer
+
+from .agent import Agent
+from .connection import Connection
+from .control import Control
+from .events import Events
+from .gossip import Gossip
+from .heart import Heart
+from .mingle import Mingle
+from .tasks import Tasks
+
+__all__ = [
+    'Consumer', 'Agent', 'Connection', 'Control',
+    'Events', 'Gossip', 'Heart', 'Mingle', 'Tasks',
+]

+ 20 - 0
celery/worker/consumer/agent.py

@@ -0,0 +1,20 @@
+from __future__ import absolute_import, unicode_literals
+
+from celery import bootsteps
+
+from .connection import Connection
+
+__all__ = ['Agent']
+
+
+class Agent(bootsteps.StartStopStep):
+
+    conditional = True
+    requires = (Connection,)
+
+    def __init__(self, c, **kwargs):
+        self.agent_cls = self.enabled = c.app.conf.worker_agent
+
+    def create(self, c):
+        agent = c.agent = self.instantiate(self.agent_cls, c.connection)
+        return agent
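Agent is the only conditional step in the set: conditional = True marks it as a step that may be left out, and inclusion comes down to self.enabled, which here doubles as the configured agent class, so an unset worker_agent setting disables the step entirely. A hedged sketch of the pattern (MyStep is illustrative, not part of this commit):

    from celery import bootsteps

    class MyStep(bootsteps.StartStopStep):
        conditional = True  # step may be excluded from the blueprint

        def __init__(self, parent, **kwargs):
            # a falsy config value makes include_if() (which returns
            # self.enabled by default) skip the step, just as Agent
            # does with app.conf.worker_agent
            self.enabled = parent.app.conf.worker_agent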

+ 33 - 0
celery/worker/consumer/connection.py

@@ -0,0 +1,33 @@
+from __future__ import absolute_import, unicode_literals
+
+from kombu.common import ignore_errors
+
+from celery import bootsteps
+from celery.utils.log import get_logger
+
+__all__ = ['Connection']
+logger = get_logger(__name__)
+info = logger.info
+
+
+class Connection(bootsteps.StartStopStep):
+
+    def __init__(self, c, **kwargs):
+        c.connection = None
+
+    def start(self, c):
+        c.connection = c.connect()
+        info('Connected to %s', c.connection.as_uri())
+
+    def shutdown(self, c):
+        # We must set self.connection to None here, so
+        # that the green pidbox thread exits.
+        connection, c.connection = c.connection, None
+        if connection:
+            ignore_errors(connection, connection.close)
+
+    def info(self, c, params='N/A'):
+        if c.connection:
+            params = c.connection.info()
+            params.pop('password', None)  # don't send password.
+        return {'broker': params}

+ 17 - 395
celery/worker/consumer.py → celery/worker/consumer/consumer.py

@@ -11,22 +11,17 @@ up and running.
 from __future__ import absolute_import
 
 import errno
-import kombu
 import logging
 import os
 
 from collections import defaultdict
-from functools import partial
-from heapq import heappush
-from operator import itemgetter
 from time import sleep
 
 from amqp.promise import ppartial, promise
 from billiard.common import restart_state
 from billiard.exceptions import RestartFreqExceeded
 from kombu.async.semaphore import DummyLock
-from kombu.common import QoS, ignore_errors
-from kombu.five import buffer_t, items, values
+from kombu.five import buffer_t, items
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr, bytes_t
 from kombu.utils.limits import TokenBucket
@@ -34,22 +29,19 @@ from kombu.utils.limits import TokenBucket
 from celery import bootsteps
 from celery import signals
 from celery.app.trace import build_tracer
-from celery.canvas import signature
 from celery.exceptions import InvalidTaskError, NotRegistered
 from celery.utils import gethostname
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
-from celery.utils.objects import Bunch
 from celery.utils.text import truncate
 from celery.utils.timeutils import humanize_seconds, rate
 
-from . import heartbeat, loops, pidbox
-from .state import task_reserved, maybe_shutdown, revoked, reserved_requests
+from celery.worker import loops
+from celery.worker.state import (
+    task_reserved, maybe_shutdown, reserved_requests,
+)
 
-__all__ = [
-    'Consumer', 'Connection', 'Events', 'Heart', 'Control',
-    'Tasks', 'Evloop', 'Agent', 'Mingle', 'Gossip', 'dump_body',
-]
+__all__ = ['Consumer', 'Evloop', 'dump_body']
 
 CLOSE = bootsteps.CLOSE
 logger = get_logger(__name__)
@@ -117,8 +109,6 @@ body: {0}
   delivery_info:{3} headers={4}}}
 """
 
-MINGLE_GET_FIELDS = itemgetter('clock', 'revoked')
-
 
 def dump_body(m, body):
     # v2 protocol does not deserialize body
@@ -130,6 +120,7 @@ def dump_body(m, body):
 
 
 class Consumer(object):
+
     Strategies = dict
 
     #: set when consumer is shutting down.
@@ -151,15 +142,15 @@ class Consumer(object):
     class Blueprint(bootsteps.Blueprint):
         name = 'Consumer'
         default_steps = [
-            'celery.worker.consumer:Connection',
-            'celery.worker.consumer:Mingle',
-            'celery.worker.consumer:Events',
-            'celery.worker.consumer:Gossip',
-            'celery.worker.consumer:Heart',
-            'celery.worker.consumer:Control',
-            'celery.worker.consumer:Tasks',
-            'celery.worker.consumer:Evloop',
-            'celery.worker.consumer:Agent',
+            'celery.worker.consumer.connection:Connection',
+            'celery.worker.consumer.mingle:Mingle',
+            'celery.worker.consumer.events:Events',
+            'celery.worker.consumer.gossip:Gossip',
+            'celery.worker.consumer.heart:Heart',
+            'celery.worker.consumer.control:Control',
+            'celery.worker.consumer.tasks:Tasks',
+            'celery.worker.consumer.consumer:Evloop',
+            'celery.worker.consumer.agent:Agent',
         ]
 
         def shutdown(self, parent):
@@ -538,377 +529,8 @@ class Consumer(object):
         )
 
 
-class Connection(bootsteps.StartStopStep):
-
-    def __init__(self, c, **kwargs):
-        c.connection = None
-
-    def start(self, c):
-        c.connection = c.connect()
-        info('Connected to %s', c.connection.as_uri())
-
-    def shutdown(self, c):
-        # We must set self.connection to None here, so
-        # that the green pidbox thread exits.
-        connection, c.connection = c.connection, None
-        if connection:
-            ignore_errors(connection, connection.close)
-
-    def info(self, c, params='N/A'):
-        if c.connection:
-            params = c.connection.info()
-            params.pop('password', None)  # don't send password.
-        return {'broker': params}
-
-
-class Events(bootsteps.StartStopStep):
-    requires = (Connection,)
-
-    def __init__(self, c, send_events=True,
-                 without_heartbeat=False, without_gossip=False, **kwargs):
-        self.groups = None if send_events else ['worker']
-        self.send_events = (
-            send_events or
-            not without_gossip or
-            not without_heartbeat
-        )
-        c.event_dispatcher = None
-
-    def start(self, c):
-        # flush events sent while connection was down.
-        prev = self._close(c)
-        dis = c.event_dispatcher = c.app.events.Dispatcher(
-            c.connect(), hostname=c.hostname,
-            enabled=self.send_events, groups=self.groups,
-            buffer_group=['task'] if c.hub else None,
-            on_send_buffered=c.on_send_event_buffered if c.hub else None,
-        )
-        if prev:
-            dis.extend_buffer(prev)
-            dis.flush()
-
-    def stop(self, c):
-        pass
-
-    def _close(self, c):
-        if c.event_dispatcher:
-            dispatcher = c.event_dispatcher
-            # remember changes from remote control commands:
-            self.groups = dispatcher.groups
-
-            # close custom connection
-            if dispatcher.connection:
-                ignore_errors(c, dispatcher.connection.close)
-            ignore_errors(c, dispatcher.close)
-            c.event_dispatcher = None
-            return dispatcher
-
-    def shutdown(self, c):
-        self._close(c)
-
-
-class Heart(bootsteps.StartStopStep):
-    requires = (Events,)
-
-    def __init__(self, c, without_heartbeat=False, heartbeat_interval=None,
-                 **kwargs):
-        self.enabled = not without_heartbeat
-        self.heartbeat_interval = heartbeat_interval
-        c.heart = None
-
-    def start(self, c):
-        c.heart = heartbeat.Heart(
-            c.timer, c.event_dispatcher, self.heartbeat_interval,
-        )
-        c.heart.start()
-
-    def stop(self, c):
-        c.heart = c.heart and c.heart.stop()
-    shutdown = stop
-
-
-class Mingle(bootsteps.StartStopStep):
-    label = 'Mingle'
-    requires = (Events,)
-    compatible_transports = {'amqp', 'redis'}
-
-    def __init__(self, c, without_mingle=False, **kwargs):
-        self.enabled = not without_mingle and self.compatible_transport(c.app)
-
-    def compatible_transport(self, app):
-        with app.connection_for_read() as conn:
-            return conn.transport.driver_type in self.compatible_transports
-
-    def start(self, c):
-        info('mingle: searching for neighbors')
-        I = c.app.control.inspect(timeout=1.0, connection=c.connection)
-        replies = I.hello(c.hostname, revoked._data) or {}
-        replies.pop(c.hostname, None)
-        if replies:
-            info('mingle: sync with %s nodes',
-                 len([reply for reply, value in items(replies) if value]))
-            for reply in values(replies):
-                if reply:
-                    try:
-                        other_clock, other_revoked = MINGLE_GET_FIELDS(reply)
-                    except KeyError:  # reply from pre-3.1 worker
-                        pass
-                    else:
-                        c.app.clock.adjust(other_clock)
-                        revoked.update(other_revoked)
-            info('mingle: sync complete')
-        else:
-            info('mingle: all alone')
-
-
-class Tasks(bootsteps.StartStopStep):
-    requires = (Mingle,)
-
-    def __init__(self, c, **kwargs):
-        c.task_consumer = c.qos = None
-
-    def start(self, c):
-        c.update_strategies()
-
-        # - RabbitMQ 3.3 completely redefines how basic_qos works..
-        # This will detect if the new qos smenatics is in effect,
-        # and if so make sure the 'apply_global' flag is set on qos updates.
-        qos_global = not c.connection.qos_semantics_matches_spec
-
-        # set initial prefetch count
-        c.connection.default_channel.basic_qos(
-            0, c.initial_prefetch_count, qos_global,
-        )
-
-        c.task_consumer = c.app.amqp.TaskConsumer(
-            c.connection, on_decode_error=c.on_decode_error,
-        )
-
-        def set_prefetch_count(prefetch_count):
-            return c.task_consumer.qos(
-                prefetch_count=prefetch_count,
-                apply_global=qos_global,
-            )
-        c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
-
-    def stop(self, c):
-        if c.task_consumer:
-            debug('Cancelling task consumer...')
-            ignore_errors(c, c.task_consumer.cancel)
-
-    def shutdown(self, c):
-        if c.task_consumer:
-            self.stop(c)
-            debug('Closing consumer channel...')
-            ignore_errors(c, c.task_consumer.close)
-            c.task_consumer = None
-
-    def info(self, c):
-        return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
-
-
-class Agent(bootsteps.StartStopStep):
-    conditional = True
-    requires = (Connection,)
-
-    def __init__(self, c, **kwargs):
-        self.agent_cls = self.enabled = c.app.conf.worker_agent
-
-    def create(self, c):
-        agent = c.agent = self.instantiate(self.agent_cls, c.connection)
-        return agent
-
-
-class Control(bootsteps.StartStopStep):
-    requires = (Tasks,)
-
-    def __init__(self, c, **kwargs):
-        self.is_green = c.pool is not None and c.pool.is_green
-        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
-        self.start = self.box.start
-        self.stop = self.box.stop
-        self.shutdown = self.box.shutdown
-
-    def include_if(self, c):
-        return (c.app.conf.worker_enable_remote_control and
-                c.conninfo.supports_exchange_type('fanout'))
-
-
-class Gossip(bootsteps.ConsumerStep):
-    label = 'Gossip'
-    requires = (Mingle,)
-    _cons_stamp_fields = itemgetter(
-        'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
-    )
-    compatible_transports = {'amqp', 'redis'}
-
-    def __init__(self, c, without_gossip=False,
-                 interval=5.0, heartbeat_interval=2.0, **kwargs):
-        self.enabled = not without_gossip and self.compatible_transport(c.app)
-        self.app = c.app
-        c.gossip = self
-        self.Receiver = c.app.events.Receiver
-        self.hostname = c.hostname
-        self.full_hostname = '.'.join([self.hostname, str(c.pid)])
-        self.on = Bunch(
-            node_join=set(),
-            node_leave=set(),
-            node_lost=set(),
-        )
-
-        self.timer = c.timer
-        if self.enabled:
-            self.state = c.app.events.State(
-                on_node_join=self.on_node_join,
-                on_node_leave=self.on_node_leave,
-                max_tasks_in_memory=1,
-            )
-            if c.hub:
-                c._mutex = DummyLock()
-            self.update_state = self.state.event
-        self.interval = interval
-        self.heartbeat_interval = heartbeat_interval
-        self._tref = None
-        self.consensus_requests = defaultdict(list)
-        self.consensus_replies = {}
-        self.event_handlers = {
-            'worker.elect': self.on_elect,
-            'worker.elect.ack': self.on_elect_ack,
-        }
-        self.clock = c.app.clock
-
-        self.election_handlers = {
-            'task': self.call_task
-        }
-
-    def compatible_transport(self, app):
-        with app.connection_for_read() as conn:
-            return conn.transport.driver_type in self.compatible_transports
-
-    def election(self, id, topic, action=None):
-        self.consensus_replies[id] = []
-        self.dispatcher.send(
-            'worker-elect',
-            id=id, topic=topic, action=action, cver=1,
-        )
-
-    def call_task(self, task):
-        try:
-            signature(task, app=self.app).apply_async()
-        except Exception as exc:
-            error('Could not call task: %r', exc, exc_info=1)
-
-    def on_elect(self, event):
-        try:
-            (id_, clock, hostname, pid,
-             topic, action, _) = self._cons_stamp_fields(event)
-        except KeyError as exc:
-            return error('election request missing field %s', exc, exc_info=1)
-        heappush(
-            self.consensus_requests[id_],
-            (clock, '%s.%s' % (hostname, pid), topic, action),
-        )
-        self.dispatcher.send('worker-elect-ack', id=id_)
-
-    def start(self, c):
-        super(Gossip, self).start(c)
-        self.dispatcher = c.event_dispatcher
-
-    def on_elect_ack(self, event):
-        id = event['id']
-        try:
-            replies = self.consensus_replies[id]
-        except KeyError:
-            return  # not for us
-        alive_workers = self.state.alive_workers()
-        replies.append(event['hostname'])
-
-        if len(replies) >= len(alive_workers):
-            _, leader, topic, action = self.clock.sort_heap(
-                self.consensus_requests[id],
-            )
-            if leader == self.full_hostname:
-                info('I won the election %r', id)
-                try:
-                    handler = self.election_handlers[topic]
-                except KeyError:
-                    error('Unknown election topic %r', topic, exc_info=1)
-                else:
-                    handler(action)
-            else:
-                info('node %s elected for %r', leader, id)
-            self.consensus_requests.pop(id, None)
-            self.consensus_replies.pop(id, None)
-
-    def on_node_join(self, worker):
-        debug('%s joined the party', worker.hostname)
-        self._call_handlers(self.on.node_join, worker)
-
-    def on_node_leave(self, worker):
-        debug('%s left', worker.hostname)
-        self._call_handlers(self.on.node_leave, worker)
-
-    def on_node_lost(self, worker):
-        info('missed heartbeat from %s', worker.hostname)
-        self._call_handlers(self.on.node_lost, worker)
-
-    def _call_handlers(self, handlers, *args, **kwargs):
-        for handler in handlers:
-            try:
-                handler(*args, **kwargs)
-            except Exception as exc:
-                error('Ignored error from handler %r: %r',
-                      handler, exc, exc_info=1)
-
-    def register_timer(self):
-        if self._tref is not None:
-            self._tref.cancel()
-        self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
-
-    def periodic(self):
-        workers = self.state.workers
-        dirty = set()
-        for worker in values(workers):
-            if not worker.alive:
-                dirty.add(worker)
-                self.on_node_lost(worker)
-        for worker in dirty:
-            workers.pop(worker.hostname, None)
-
-    def get_consumers(self, channel):
-        self.register_timer()
-        ev = self.Receiver(channel, routing_key='worker.#',
-                           queue_ttl=self.heartbeat_interval)
-        return [kombu.Consumer(
-            channel,
-            queues=[ev.queue],
-            on_message=partial(self.on_message, ev.event_from_message),
-            no_ack=True
-        )]
-
-    def on_message(self, prepare, message):
-        _type = message.delivery_info['routing_key']
-
-        # For redis when `fanout_patterns=False` (See Issue #1882)
-        if _type.split('.', 1)[0] == 'task':
-            return
-        try:
-            handler = self.event_handlers[_type]
-        except KeyError:
-            pass
-        else:
-            return handler(message.payload)
-
-        hostname = (message.headers.get('hostname') or
-                    message.payload['hostname'])
-        if hostname != self.hostname:
-            type, event = prepare(message.payload)
-            self.update_state(event)
-        else:
-            self.clock.forward()
-
-
 class Evloop(bootsteps.StartStopStep):
+
     label = 'event loop'
     last = True
 

+ 27 - 0
celery/worker/consumer/control.py

@@ -0,0 +1,27 @@
+from __future__ import absolute_import, unicode_literals
+
+from celery import bootsteps
+from celery.utils.log import get_logger
+
+from celery.worker import pidbox
+
+from .tasks import Tasks
+
+__all__ = ['Control']
+logger = get_logger(__name__)
+
+
+class Control(bootsteps.StartStopStep):
+
+    requires = (Tasks,)
+
+    def __init__(self, c, **kwargs):
+        self.is_green = c.pool is not None and c.pool.is_green
+        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
+        self.start = self.box.start
+        self.stop = self.box.stop
+        self.shutdown = self.box.shutdown
+
+    def include_if(self, c):
+        return (c.app.conf.worker_enable_remote_control and
+                c.conninfo.supports_exchange_type('fanout'))
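Control takes the other route and overrides include_if() directly: the remote-control pidbox is wired in only when the worker_enable_remote_control setting allows it and the broker transport supports 'fanout' exchanges. Roughly what the blueprint's inclusion check amounts to (a simplified sketch of blueprint internals):

    step = Control(c)
    if step.include_if(c):
        # remote control enabled *and* broker can do fanout
        ...  # step becomes part of the consumer blueprint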

+ 56 - 0
celery/worker/consumer/events.py

@@ -0,0 +1,56 @@
+from __future__ import absolute_import, unicode_literals
+
+from kombu.common import ignore_errors
+
+from celery import bootsteps
+
+from .connection import Connection
+
+__all__ = ['Events']
+
+
+class Events(bootsteps.StartStopStep):
+
+    requires = (Connection,)
+
+    def __init__(self, c, send_events=True,
+                 without_heartbeat=False, without_gossip=False, **kwargs):
+        self.groups = None if send_events else ['worker']
+        self.send_events = (
+            send_events or
+            not without_gossip or
+            not without_heartbeat
+        )
+        c.event_dispatcher = None
+
+    def start(self, c):
+        # flush events sent while connection was down.
+        prev = self._close(c)
+        dis = c.event_dispatcher = c.app.events.Dispatcher(
+            c.connect(), hostname=c.hostname,
+            enabled=self.send_events, groups=self.groups,
+            buffer_group=['task'] if c.hub else None,
+            on_send_buffered=c.on_send_event_buffered if c.hub else None,
+        )
+        if prev:
+            dis.extend_buffer(prev)
+            dis.flush()
+
+    def stop(self, c):
+        pass
+
+    def _close(self, c):
+        if c.event_dispatcher:
+            dispatcher = c.event_dispatcher
+            # remember changes from remote control commands:
+            self.groups = dispatcher.groups
+
+            # close custom connection
+            if dispatcher.connection:
+                ignore_errors(c, dispatcher.connection.close)
+            ignore_errors(c, dispatcher.close)
+            c.event_dispatcher = None
+            return dispatcher
+
+    def shutdown(self, c):
+        self._close(c)
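Note the send_events expression: even when event publishing is switched off, the dispatcher must stay enabled if gossip or heartbeats are on, since both ride on it; only the group filter is narrowed. The resulting cases, spelled out (derived from the two assignments in __init__):

    # send_events  without_gossip  without_heartbeat -> enabled  groups
    # True         *               *                    True     None (all)
    # False        False           *                    True     ['worker']
    # False        True            False                True     ['worker']
    # False        True            True                 False    ['worker']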

+ 195 - 0
celery/worker/consumer/gossip.py

@@ -0,0 +1,195 @@
+from __future__ import absolute_import, unicode_literals
+
+from collections import defaultdict
+from functools import partial
+from heapq import heappush
+from operator import itemgetter
+
+from kombu import Consumer
+from kombu.async.semaphore import DummyLock
+
+from celery import bootsteps
+from celery.five import values
+from celery.utils.log import get_logger
+from celery.utils.objects import Bunch
+
+from .mingle import Mingle
+
+__all__ = ['Gossip']
+logger = get_logger(__name__)
+debug, info, error = logger.debug, logger.info, logger.error
+
+
+class Gossip(bootsteps.ConsumerStep):
+
+    label = 'Gossip'
+    requires = (Mingle,)
+    _cons_stamp_fields = itemgetter(
+        'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
+    )
+    compatible_transports = {'amqp', 'redis'}
+
+    def __init__(self, c, without_gossip=False,
+                 interval=5.0, heartbeat_interval=2.0, **kwargs):
+        self.enabled = not without_gossip and self.compatible_transport(c.app)
+        self.app = c.app
+        c.gossip = self
+        self.Receiver = c.app.events.Receiver
+        self.hostname = c.hostname
+        self.full_hostname = '.'.join([self.hostname, str(c.pid)])
+        self.on = Bunch(
+            node_join=set(),
+            node_leave=set(),
+            node_lost=set(),
+        )
+
+        self.timer = c.timer
+        if self.enabled:
+            self.state = c.app.events.State(
+                on_node_join=self.on_node_join,
+                on_node_leave=self.on_node_leave,
+                max_tasks_in_memory=1,
+            )
+            if c.hub:
+                c._mutex = DummyLock()
+            self.update_state = self.state.event
+        self.interval = interval
+        self.heartbeat_interval = heartbeat_interval
+        self._tref = None
+        self.consensus_requests = defaultdict(list)
+        self.consensus_replies = {}
+        self.event_handlers = {
+            'worker.elect': self.on_elect,
+            'worker.elect.ack': self.on_elect_ack,
+        }
+        self.clock = c.app.clock
+
+        self.election_handlers = {
+            'task': self.call_task
+        }
+
+    def compatible_transport(self, app):
+        with app.connection_for_read() as conn:
+            return conn.transport.driver_type in self.compatible_transports
+
+    def election(self, id, topic, action=None):
+        self.consensus_replies[id] = []
+        self.dispatcher.send(
+            'worker-elect',
+            id=id, topic=topic, action=action, cver=1,
+        )
+
+    def call_task(self, task):
+        try:
+            self.app.signature(task).apply_async()
+        except Exception as exc:
+            error('Could not call task: %r', exc, exc_info=1)
+
+    def on_elect(self, event):
+        try:
+            (id_, clock, hostname, pid,
+             topic, action, _) = self._cons_stamp_fields(event)
+        except KeyError as exc:
+            return error('election request missing field %s', exc, exc_info=1)
+        heappush(
+            self.consensus_requests[id_],
+            (clock, '%s.%s' % (hostname, pid), topic, action),
+        )
+        self.dispatcher.send('worker-elect-ack', id=id_)
+
+    def start(self, c):
+        super(Gossip, self).start(c)
+        self.dispatcher = c.event_dispatcher
+
+    def on_elect_ack(self, event):
+        id = event['id']
+        try:
+            replies = self.consensus_replies[id]
+        except KeyError:
+            return  # not for us
+        alive_workers = self.state.alive_workers()
+        replies.append(event['hostname'])
+
+        if len(replies) >= len(alive_workers):
+            _, leader, topic, action = self.clock.sort_heap(
+                self.consensus_requests[id],
+            )
+            if leader == self.full_hostname:
+                info('I won the election %r', id)
+                try:
+                    handler = self.election_handlers[topic]
+                except KeyError:
+                    error('Unknown election topic %r', topic, exc_info=1)
+                else:
+                    handler(action)
+            else:
+                info('node %s elected for %r', leader, id)
+            self.consensus_requests.pop(id, None)
+            self.consensus_replies.pop(id, None)
+
+    def on_node_join(self, worker):
+        debug('%s joined the party', worker.hostname)
+        self._call_handlers(self.on.node_join, worker)
+
+    def on_node_leave(self, worker):
+        debug('%s left', worker.hostname)
+        self._call_handlers(self.on.node_leave, worker)
+
+    def on_node_lost(self, worker):
+        info('missed heartbeat from %s', worker.hostname)
+        self._call_handlers(self.on.node_lost, worker)
+
+    def _call_handlers(self, handlers, *args, **kwargs):
+        for handler in handlers:
+            try:
+                handler(*args, **kwargs)
+            except Exception as exc:
+                error('Ignored error from handler %r: %r',
+                      handler, exc, exc_info=1)
+
+    def register_timer(self):
+        if self._tref is not None:
+            self._tref.cancel()
+        self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
+
+    def periodic(self):
+        workers = self.state.workers
+        dirty = set()
+        for worker in values(workers):
+            if not worker.alive:
+                dirty.add(worker)
+                self.on_node_lost(worker)
+        for worker in dirty:
+            workers.pop(worker.hostname, None)
+
+    def get_consumers(self, channel):
+        self.register_timer()
+        ev = self.Receiver(channel, routing_key='worker.#',
+                           queue_ttl=self.heartbeat_interval)
+        return [Consumer(
+            channel,
+            queues=[ev.queue],
+            on_message=partial(self.on_message, ev.event_from_message),
+            no_ack=True
+        )]
+
+    def on_message(self, prepare, message):
+        _type = message.delivery_info['routing_key']
+
+        # For redis when `fanout_patterns=False` (See Issue #1882)
+        if _type.split('.', 1)[0] == 'task':
+            return
+        try:
+            handler = self.event_handlers[_type]
+        except KeyError:
+            pass
+        else:
+            return handler(message.payload)
+
+        hostname = (message.headers.get('hostname') or
+                    message.payload['hostname'])
+        if hostname != self.hostname:
+            type, event = prepare(message.payload)
+            self.update_state(event)
+        else:
+            self.clock.forward()
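One behavioral detail hidden in the move: call_task() now builds the signature via self.app.signature(task) instead of the module-level celery.canvas.signature(task, app=self.app). The two spellings should be equivalent ways to bind a signature to an app, which is why the reworked test above stubs g.app.signature rather than patching a module attribute (app and task_dict below stand in for a bound Celery app and a serialized signature dict):

    from celery import signature  # celery.canvas.signature

    sig = app.signature(task_dict)       # new spelling in gossip.py
    sig = signature(task_dict, app=app)  # old spelling, same effect
    sig.apply_async()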

+ 30 - 0
celery/worker/consumer/heart.py

@@ -0,0 +1,30 @@
+from __future__ import absolute_import, unicode_literals
+
+from celery import bootsteps
+
+from celery.worker import heartbeat
+
+from .events import Events
+
+__all__ = ['Heart']
+
+
+class Heart(bootsteps.StartStopStep):
+
+    requires = (Events,)
+
+    def __init__(self, c,
+                 without_heartbeat=False, heartbeat_interval=None, **kwargs):
+        self.enabled = not without_heartbeat
+        self.heartbeat_interval = heartbeat_interval
+        c.heart = None
+
+    def start(self, c):
+        c.heart = heartbeat.Heart(
+            c.timer, c.event_dispatcher, self.heartbeat_interval,
+        )
+        c.heart.start()
+
+    def stop(self, c):
+        c.heart = c.heart and c.heart.stop()
+    shutdown = stop
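stop() relies on a small idiom: assuming heartbeat.Heart.stop() returns None, the one-liner both stops a running heart and clears the attribute, and is a no-op when start() never ran:

    # c.heart is None  -> 'and' short-circuits; c.heart stays None
    # c.heart is set   -> stop() runs, returns None; c.heart becomes None
    c.heart = c.heart and c.heart.stop()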

+ 53 - 0
celery/worker/consumer/mingle.py

@@ -0,0 +1,53 @@
+from __future__ import absolute_import, unicode_literals
+
+from operator import itemgetter
+
+from celery import bootsteps
+from celery.five import items, values
+from celery.utils.log import get_logger
+
+from celery.worker.state import revoked
+
+from .events import Events
+
+__all__ = ['Mingle']
+
+MINGLE_GET_FIELDS = itemgetter('clock', 'revoked')
+
+logger = get_logger(__name__)
+info = logger.info
+
+
+class Mingle(bootsteps.StartStopStep):
+
+    label = 'Mingle'
+    requires = (Events,)
+    compatible_transports = {'amqp', 'redis'}
+
+    def __init__(self, c, without_mingle=False, **kwargs):
+        self.enabled = not without_mingle and self.compatible_transport(c.app)
+
+    def compatible_transport(self, app):
+        with app.connection_for_read() as conn:
+            return conn.transport.driver_type in self.compatible_transports
+
+    def start(self, c):
+        info('mingle: searching for neighbors')
+        I = c.app.control.inspect(timeout=1.0, connection=c.connection)
+        replies = I.hello(c.hostname, revoked._data) or {}
+        replies.pop(c.hostname, None)
+        if replies:
+            info('mingle: sync with %s nodes',
+                 len([reply for reply, value in items(replies) if value]))
+            for reply in values(replies):
+                if reply:
+                    try:
+                        other_clock, other_revoked = MINGLE_GET_FIELDS(reply)
+                    except KeyError:  # reply from pre-3.1 worker
+                        pass
+                    else:
+                        c.app.clock.adjust(other_clock)
+                        revoked.update(other_revoked)
+            info('mingle: sync complete')
+        else:
+            info('mingle: all alone')
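MINGLE_GET_FIELDS pins down the reply shape start() expects from hello(): any reply missing 'clock' or 'revoked' is taken to come from a pre-3.1 worker and skipped. With illustrative values:

    # one entry of the mapping returned by inspect().hello(...)
    reply = {'clock': 312, 'revoked': {'task-id-1', 'task-id-2'}}
    other_clock, other_revoked = MINGLE_GET_FIELDS(reply)
    # -> (312, {'task-id-1', 'task-id-2'}); a KeyError instead means
    #    the neighbor predates the mingle protocol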

+ 59 - 0
celery/worker/consumer/tasks.py

@@ -0,0 +1,59 @@
+from __future__ import absolute_import, unicode_literals
+
+from kombu.common import QoS, ignore_errors
+
+from celery import bootsteps
+from celery.utils.log import get_logger
+
+from .mingle import Mingle
+
+__all__ = ['Tasks']
+logger = get_logger(__name__)
+debug = logger.debug
+
+
+class Tasks(bootsteps.StartStopStep):
+
+    requires = (Mingle,)
+
+    def __init__(self, c, **kwargs):
+        c.task_consumer = c.qos = None
+
+    def start(self, c):
+        c.update_strategies()
+
+        # - RabbitMQ 3.3 completely redefines how basic_qos works...
+        # This will detect if the new qos semantics are in effect,
+        # and if so make sure the 'apply_global' flag is set on qos updates.
+        qos_global = not c.connection.qos_semantics_matches_spec
+
+        # set initial prefetch count
+        c.connection.default_channel.basic_qos(
+            0, c.initial_prefetch_count, qos_global,
+        )
+
+        c.task_consumer = c.app.amqp.TaskConsumer(
+            c.connection, on_decode_error=c.on_decode_error,
+        )
+
+        def set_prefetch_count(prefetch_count):
+            return c.task_consumer.qos(
+                prefetch_count=prefetch_count,
+                apply_global=qos_global,
+            )
+        c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
+
+    def stop(self, c):
+        if c.task_consumer:
+            debug('Cancelling task consumer...')
+            ignore_errors(c, c.task_consumer.cancel)
+
+    def shutdown(self, c):
+        if c.task_consumer:
+            self.stop(c)
+            debug('Closing consumer channel...')
+            ignore_errors(c, c.task_consumer.close)
+            c.task_consumer = None
+
+    def info(self, c):
+        return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
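Taken together, the requires tuples above turn the old implicit ordering into an explicit dependency chain across modules; sketched from this commit's sources:

    # Connection                  (connection.py)
    #  |- Agent                   (agent.py, conditional)
    #  '- Events                  (events.py)
    #      |- Heart               (heart.py)
    #      '- Mingle              (mingle.py)
    #          |- Gossip          (gossip.py)
    #          '- Tasks           (tasks.py)
    #              '- Control     (control.py)
    # Evloop stays in consumer.py and runs last (last = True)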