Ask Solem 11 years ago
parent
commit
40e1a00e53
2 changed files with 37 additions and 3 deletions
  1. 6 2
      celery/events/state.py
  2. 31 1
      celery/tests/events/test_state.py

+ 6 - 2
celery/events/state.py

@@ -516,14 +516,18 @@ class State(object):
                 except KeyError:
                     pass
                 else:
+                    is_offline = subject == 'offline'
                     try:
                         worker, created = get_worker(hostname), False
                     except KeyError:
-                        worker = workers[hostname] = Worker(hostname)
+                        if is_offline:
+                            worker, created = Worker(hostname), False
+                        else:
+                            worker = workers[hostname] = Worker(hostname)
                     worker.event(subject, timestamp, local_received, event)
                     if on_node_join and (created or subject == 'online'):
                         on_node_join(worker)
-                    if on_node_leave and subject == 'offline':
+                    if on_node_leave and is_offline:
                         on_node_leave(worker)
                     return (worker, created), subject
             elif group == 'task':

+ 31 - 1
celery/tests/events/test_state.py

@@ -18,7 +18,7 @@ from celery.events.state import (
 )
 from celery.five import range
 from celery.utils import uuid
-from celery.tests.case import AppCase, patch
+from celery.tests.case import AppCase, Mock, patch
 
 try:
     Decimal(2.6)
@@ -487,6 +487,36 @@ class test_State(AppCase):
             'foo': 'bar',
         })
 
+    def test_survives_unknown_worker_leaving(self):
+        s = State(on_node_leave=Mock(name='on_node_leave'))
+        (worker, created), subject = s.event({
+            'type': 'worker-offline',
+            'hostname': 'unknown@vandelay.com',
+            'timestamp': time(),
+            'local_received': time(),
+            'clock': 301030134894833,
+        })
+        self.assertEqual(worker, Worker('unknown@vandelay.com'))
+        self.assertFalse(created)
+        self.assertEqual(subject, 'offline')
+        self.assertNotIn('unknown@vandelay.com', s.workers)
+        s.on_node_leave.assert_called_with(worker)
+
+    def test_on_node_join_callback(self):
+        s = State(on_node_join=Mock(name='on_node_join'))
+        (worker, created), subject = s.event({
+            'type': 'worker-online',
+            'hostname': 'george@vandelay.com',
+            'timestamp': time(),
+            'local_received': time(),
+            'clock': 34314,
+        })
+        self.assertTrue(worker)
+        self.assertTrue(created)
+        self.assertEqual(subject, 'online')
+        self.assertIn('george@vandelay.com', s.workers)
+        s.on_node_join.assert_called_with(worker)
+
     def test_survives_unknown_task_event(self):
         s = State()
         s.event(