Browse Source

events.State now supports on_node_join/on_node_leave callbacks; also fixes #1728

Ask Solem 11 years ago
parent
commit
ef15f051e0
3 changed files with 32 additions and 23 deletions
  1. celery/events/state.py (+21 -8)
  2. celery/tests/worker/test_consumer.py (+5 -4)
  3. celery/worker/consumer.py (+6 -11)

+ 21 - 8
celery/events/state.py

@@ -370,7 +370,8 @@ class State(object):
 
     def __init__(self, callback=None,
                  workers=None, tasks=None, taskheap=None,
-                 max_workers_in_memory=5000, max_tasks_in_memory=10000):
+                 max_workers_in_memory=5000, max_tasks_in_memory=10000,
+                 on_node_join=None, on_node_leave=None):
         self.event_callback = callback
         self.workers = (LRUCache(max_workers_in_memory)
                         if workers is None else workers)
@@ -379,6 +380,8 @@ class State(object):
         self._taskheap = [] if taskheap is None else taskheap
         self.max_workers_in_memory = max_workers_in_memory
         self.max_tasks_in_memory = max_tasks_in_memory
+        self.on_node_join = on_node_join
+        self.on_node_leave = on_node_leave
         self._mutex = threading.Lock()
         self.handlers = {}
         self._seen_types = set()
@@ -451,11 +454,11 @@ class State(object):
 
     def task_event(self, type_, fields):
         """Deprecated, use :meth:`event`."""
-        return self._event(dict(fields, type='-'.join(['task', type_])))
+        return self._event(dict(fields, type='-'.join(['task', type_])))[0]
 
     def worker_event(self, type_, fields):
         """Deprecated, use :meth:`event`."""
-        return self._event(dict(fields, type='-'.join(['worker', type_])))
+        return self._event(dict(fields, type='-'.join(['worker', type_])))[0]
 
     def _create_dispatcher(self):
         get_handler = self.handlers.__getitem__
@@ -470,6 +473,7 @@ class State(object):
         #: an O(n) operation
         max_events_in_heap = self.max_tasks_in_memory * self.heap_multiplier
         add_type = self._seen_types.add
+        on_node_join, on_node_leave = self.on_node_join, self.on_node_leave
         tasks, Task = self.tasks, self.Task
         workers, Worker = self.workers, self.Worker
         # avoid updating LRU entry at getitem
@@ -486,7 +490,7 @@ class State(object):
             except KeyError:
                 pass
             else:
-                return handler(subject, event)
+                return handler(subject, event), subject
 
             if group == 'worker':
                 try:
@@ -497,9 +501,17 @@ class State(object):
                     try:
                         worker, created = get_worker(hostname), False
                     except KeyError:
-                        worker = workers[hostname] = Worker(hostname)
-                    worker.event(subject, timestamp, local_received, event)
-                    return created
+                        if subject == 'offline':
+                            worker, created = None, False
+                        else:
+                            worker = workers[hostname] = Worker(hostname)
+                    if worker:
+                        worker.event(subject, timestamp, local_received, event)
+                    if on_node_join and (created or subject == 'online'):
+                        on_node_join(worker)
+                    if on_node_leave and subject == 'offline':
+                        on_node_leave(worker)
+                    return (worker, created), subject
             elif group == 'task':
                 (uuid, hostname, timestamp,
                  local_received, clock) = tfields(event)
@@ -530,7 +542,7 @@ class State(object):
                 task_name = task.name
                 if task_name is not None:
                     add_type(task_name)
-                return created
+                return (task, created), subject
         return _event
 
     def rebuild_taskheap(self, timetuple=timetuple, heapify=heapify):
@@ -596,4 +608,5 @@ class State(object):
         return self.__class__, (
             self.event_callback, self.workers, self.tasks, None,
             self.max_workers_in_memory, self.max_tasks_in_memory,
+            self.on_node_join, self.on_node_leave,
         )

+ 5 - 4
celery/tests/worker/test_consumer.py

@@ -433,8 +433,13 @@ class test_Gossip(AppCase):
     def test_on_message(self):
         c = self.Consumer()
         g = Gossip(c)
+        self.assertTrue(g.enabled)
         prepare = Mock()
         prepare.return_value = 'worker-online', {}
+        c.app.events.State.assert_called_with(
+            on_node_join=g.on_node_join,
+            on_node_leave=g.on_node_leave,
+        )
         g.update_state = Mock()
         worker = Mock()
         g.on_node_join = Mock()
@@ -450,20 +455,16 @@ class test_Gossip(AppCase):
         g.event_handlers = {}
 
         g.on_message(prepare, message)
-        g.on_node_join.assert_called_with(worker)
 
         message.delivery_info = {'routing_key': 'worker-offline'}
         prepare.return_value = 'worker-offline', {}
         g.on_message(prepare, message)
-        g.on_node_leave.assert_called_with(worker)
 
         message.delivery_info = {'routing_key': 'worker-baz'}
         prepare.return_value = 'worker-baz', {}
         g.update_state.return_value = worker, 0
         g.on_message(prepare, message)
 
-        g.on_node_leave.reset_mock()
         message.headers = {'hostname': g.hostname}
         g.on_message(prepare, message)
-        self.assertFalse(g.on_node_leave.called)
         g.clock.forward.assert_called_with()

+ 6 - 11
celery/worker/consumer.py

@@ -645,10 +645,13 @@ class Gossip(bootsteps.ConsumerStep):
 
         self.timer = c.timer
         if self.enabled:
-            self.state = c.app.events.State()
+            self.state = c.app.events.State(
+                on_node_join=self.on_node_join,
+                on_node_leave=self.on_node_leave,
+            )
             if c.hub:
                 c._mutex = DummyLock()
-            self.update_state = self.state.worker_event
+            self.update_state = self.state.event
         self.interval = interval
         self._tref = None
         self.consensus_requests = defaultdict(list)
@@ -769,15 +772,7 @@ class Gossip(bootsteps.ConsumerStep):
                     message.payload['hostname'])
         if hostname != self.hostname:
             type, event = prepare(message.payload)
-            group, _, subject = type.partition('-')
-            worker, created = self.update_state(subject, event)
-            if subject == 'offline':
-                try:
-                    self.on_node_leave(worker)
-                finally:
-                    self.state.workers.pop(worker.hostname, None)
-            elif created or subject == 'online':
-                self.on_node_join(worker)
+            obj, subject = self.update_state(event)
         else:
             self.clock.forward()