Переглянути джерело

[gossip] WIP leader election implementation

Ask Solem 12 років тому
батько
коміт
fc3529fab0

+ 6 - 1
celery/app/control.py

@@ -93,7 +93,7 @@ class Control(object):
     def __init__(self, app=None):
     def __init__(self, app=None):
         self.app = app_or_default(app)
         self.app = app_or_default(app)
         self.mailbox = self.Mailbox('celery',
         self.mailbox = self.Mailbox('celery',
-                type='fanout', clock=self.app.clock)
+                type='fanout', )#clock=self.app.clock)
 
 
     @cached_property
     @cached_property
     def inspect(self):
     def inspect(self):
@@ -112,6 +112,11 @@ class Control(object):
             return self.app.amqp.TaskConsumer(conn).purge()
             return self.app.amqp.TaskConsumer(conn).purge()
     discard_all = purge
     discard_all = purge
 
 
+    def election(self, id, topic, action=None, connection=None):
+        """Broadcast an ``election`` control command.
+
+        The command arguments carry the election ``id``, its ``topic``
+        and the ``action`` payload; ``connection`` is handed through to
+        ``broadcast`` unchanged.
+
+        """
+        arguments = {'id': id, 'topic': topic, 'action': action}
+        self.broadcast('election', arguments=arguments, connection=connection)
+
     def revoke(self, task_id, destination=None, terminate=False,
     def revoke(self, task_id, destination=None, terminate=False,
             signal='SIGTERM', **kwargs):
             signal='SIGTERM', **kwargs):
         """Tell all (or specific) workers to revoke a task by id.
         """Tell all (or specific) workers to revoke a task by id.

+ 4 - 0
celery/backends/amqrpc.py

@@ -30,6 +30,10 @@ class AMQRPCBackend(amqp.AMQPBackend):
 
 
     def on_task_call(self, producer, task_id):
     def on_task_call(self, producer, task_id):
         maybe_declare(self.binding(producer.channel), retry=True)
         maybe_declare(self.binding(producer.channel), retry=True)
+        return self.extra_properties
+
+    @property
+    def extra_properties(self):
         return {'reply_to': self.oid}
         return {'reply_to': self.oid}
 
 
     def _create_binding(self, task_id):
     def _create_binding(self, task_id):

+ 12 - 0
celery/canvas.py

@@ -188,6 +188,18 @@ class Signature(dict):
         args, kwargs, _ = self._merge(args, kwargs, {})
         args, kwargs, _ = self._merge(args, kwargs, {})
         return reprcall(self['task'], args, kwargs)
         return reprcall(self['task'], args, kwargs)
 
 
+    def election(self):
+        """Start an election for calling this signature.
+
+        Publishes an ``election`` control command with topic ``'task'``
+        whose action is a clone of this signature, and returns the
+        ``AsyncResult`` for the task id used (the ``task_id`` option if
+        set, otherwise a fresh uuid).
+
+        """
+        task = self.type
+        app = task.app
+        tid = self.options.get('task_id') or uuid()
+
+        with app.producer_or_acquire(None) as P:
+            # The backend may contribute extra options (e.g. ``reply_to``);
+            # tolerate backends whose hook returns nothing.
+            props = task.backend.on_task_call(P, tid) or {}
+            app.control.election(tid, 'task',
+                                 self.clone(task_id=tid, **props),
+                                 connection=P.connection)
+            return task.AsyncResult(tid)
+
     def __repr__(self):
     def __repr__(self):
         return self.reprcall()
         return self.reprcall()
 
 

+ 82 - 3
celery/worker/consumer.py

@@ -14,7 +14,10 @@ import kombu
 import logging
 import logging
 import socket
 import socket
 
 
+from collections import defaultdict
 from functools import partial
 from functools import partial
+from heapq import heappush
+from operator import itemgetter
 
 
 from kombu.common import QoS, ignore_errors
 from kombu.common import QoS, ignore_errors
 from kombu.syn import _detect_environment
 from kombu.syn import _detect_environment
@@ -22,6 +25,7 @@ from kombu.utils.encoding import safe_repr
 
 
 from celery import bootsteps
 from celery import bootsteps
 from celery.app import app_or_default
 from celery.app import app_or_default
+from celery.canvas import subtask
 from celery.task.trace import build_tracer
 from celery.task.trace import build_tracer
 from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.functional import noop
 from celery.utils.functional import noop
@@ -475,17 +479,82 @@ class Agent(bootsteps.StartStopStep):
 
 
 class Gossip(bootsteps.ConsumerStep):
 class Gossip(bootsteps.ConsumerStep):
     label = 'gossip'
     label = 'gossip'
-    requires = (Connection, )
+    requires = (Events, )
+    _cons_stamp_fields = itemgetter(
+        'clock', 'hostname', 'pid', 'topic', 'action',
+    )
 
 
     def __init__(self, c, interval=5.0, **kwargs):
     def __init__(self, c, interval=5.0, **kwargs):
+        self.app = c.app
+        c.gossip = self
         self.Receiver = c.app.events.Receiver
         self.Receiver = c.app.events.Receiver
         self.hostname = c.hostname
         self.hostname = c.hostname
 
 
         self.timer = c.timer
         self.timer = c.timer
-        self.state = c.gossip = c.app.events.State()
+        self.state = c.app.events.State()
         self.interval = interval
         self.interval = interval
         self._tref = None
         self._tref = None
+        self.consensus_requests = defaultdict(list)
+        self.consensus_replies = {}
         self.update_state = self.state.worker_event
         self.update_state = self.state.worker_event
+        self.event_handlers = {
+            'worker.elect': self.on_elect,
+            'worker.elect.ack': self.on_elect_ack,
+        }
+        self.clock = c.app.clock
+
+        self.election_handlers = {
+            'task': self.call_task
+        }
+
+    def election(self, id, topic, action=None):
+        """Open an election and announce it with a ``worker-elect`` event.
+
+        An empty reply list is registered for ``id`` before the event
+        is sent.
+
+        """
+        self.consensus_replies[id] = []
+        self.dispatcher.send(
+            'worker-elect', id=id, topic=topic, action=action,
+        )
+
+    def call_task(self, task):
+        """Apply ``task`` asynchronously, logging any failure.
+
+        ``task`` is converted with ``subtask`` first; every
+        ``Exception`` raised while converting or sending is logged
+        and swallowed.
+
+        """
+        try:
+            subtask(task).apply_async()
+        except Exception as exc:
+            error('Could not call task: %r', exc, exc_info=1)
+
+    def on_elect(self, event):
+        """Acknowledge an election request and queue it for ordering.
+
+        The ack is sent first; the request is then pushed onto the heap
+        kept for this election id as
+        ``(clock, 'hostname.pid', topic, action)``.
+
+        """
+        election_id = event['id']
+        self.dispatcher.send('worker-elect-ack', id=election_id)
+        clock, hostname, pid, topic, action = self._cons_stamp_fields(event)
+        pending = self.consensus_requests[election_id]
+        candidate = '%s.%s' % (hostname, pid)
+        heappush(pending, (clock, candidate, topic, action))
+
+    def start(self, c):
+        """Start the step, then pick up the consumer's event dispatcher."""
+        super(Gossip, self).start(c)
+        # The election methods send their events through this dispatcher.
+        self.dispatcher = c.event_dispatcher
+
+    def on_elect_ack(self, event):
+        """Record an election ack and settle the election once all are in.
+
+        Acks for elections not registered in ``consensus_replies`` are
+        ignored.  When at least as many acks as alive workers have been
+        collected, the winning request is picked with the logical
+        clock's ``sort_heap`` and, if this node is the winner, the
+        handler registered for the request's topic is called with its
+        action.  The bookkeeping for the election is dropped afterwards.
+
+        """
+        import os  # local import: only needed to build this node's id
+
+        id = event['id']
+        try:
+            replies = self.consensus_replies[id]
+        except KeyError:
+            return
+        alive_workers = self.state.alive_workers()
+        replies.append(event['hostname'])
+
+        if len(replies) >= len(alive_workers):
+            # ``__init__`` stores the logical clock as ``self.clock``;
+            # it sets no ``lock`` attribute.
+            _, leader, topic, action = self.clock.sort_heap(
+                self.consensus_requests[id],
+            )
+            # Requests are stamped as 'hostname.pid' by ``on_elect``, so
+            # the comparison must use the same form or it never matches.
+            # NOTE(review): assumes the event's ``pid`` field is the
+            # sending worker's ``os.getpid()`` -- confirm.
+            if leader == '%s.%s' % (self.hostname, os.getpid()):
+                info('I won the election %r', id)
+                try:
+                    handler = self.election_handlers[topic]
+                except KeyError:
+                    error('Unknown election topic %r', topic, exc_info=1)
+                else:
+                    handler(action)
+            else:
+                info('Node %s elected for %r', leader, id)
+            self.consensus_requests.pop(id, None)
+            self.consensus_replies.pop(id, None)
 
 
     def on_node_join(self, worker):
     def on_node_join(self, worker):
         info('{0.hostname} joined the party'.format(worker))
         info('{0.hostname} joined the party'.format(worker))
@@ -494,7 +563,7 @@ class Gossip(bootsteps.ConsumerStep):
         info('{0.hostname} left'.format(worker))
         info('{0.hostname} left'.format(worker))
 
 
     def on_node_lost(self, worker):
     def on_node_lost(self, worker):
-        warning('{0.hostname} went missing!')
+        """Log a warning naming the worker that went missing."""
+        # The template must be formatted with the worker, as the
+        # join/leave handlers do, or the literal placeholder is logged.
+        warn('{0.hostname} went missing!'.format(worker))
 
 
     def register_timer(self):
     def register_timer(self):
         if self._tref is not None:
         if self._tref is not None:
@@ -518,6 +587,14 @@ class Gossip(bootsteps.ConsumerStep):
                     no_ack=True)]
                     no_ack=True)]
 
 
     def on_message(self, prepare, message):
     def on_message(self, prepare, message):
+        _type = message.delivery_info['routing_key']
+        try:
+            handler = self.event_handlers[_type]
+        except KeyError:
+            pass
+        else:
+            return handler(message.payload)
+
         hostname = (message.headers.get('hostname') or
         hostname = (message.headers.get('hostname') or
                     message.payload['hostname'])
                     message.payload['hostname'])
         if hostname != self.hostname:
         if hostname != self.hostname:
@@ -531,6 +608,8 @@ class Gossip(bootsteps.ConsumerStep):
                     self.state.workers.pop(worker.hostname, None)
                     self.state.workers.pop(worker.hostname, None)
             elif created or subject == 'online':
             elif created or subject == 'online':
                 self.on_node_join(worker)
                 self.on_node_join(worker)
+        else:
+            self.clock.forward()
 
 
 
 
 class Evloop(bootsteps.StartStopStep):
 class Evloop(bootsteps.StartStopStep):

+ 6 - 0
celery/worker/control.py

@@ -287,3 +287,9 @@ def active_queues(panel):
 @Panel.register
 @Panel.register
 def dump_conf(panel, **kwargs):
 def dump_conf(panel, **kwargs):
     return jsonify(dict(panel.app.conf))
     return jsonify(dict(panel.app.conf))
+
+
+
+@Panel.register
+def election(panel, id, topic, action=None, **kwargs):
+    """Remote control command: start an election on this worker.
+
+    Delegates to the ``election`` method of the consumer's gossip
+    object; extra keyword arguments are accepted and ignored.
+
+    """
+    gossip = panel.consumer.gossip
+    gossip.election(id, topic, action)