فهرست منبع

99% coverage for celery.events.state

Ask Solem 14 سال پیش
والد
کامیت
2c41cd2e37
2فایلهای تغییر یافته به همراه159 افزوده شده و 9 حذف شده
  1. 7 6
      celery/events/state.py
  2. 152 3
      celery/tests/test_events_state.py

+ 7 - 6
celery/events/state.py

@@ -229,11 +229,12 @@ class State(object):
 
     def worker_event(self, type, fields):
         """Process worker event."""
-        hostname = fields.pop("hostname")
-        worker = self.get_or_create_worker(hostname)
-        handler = getattr(worker, "on_%s" % type)
-        if handler:
-            handler(**fields)
+        hostname = fields.pop("hostname", None)
+        if hostname:
+            worker = self.get_or_create_worker(hostname)
+            handler = getattr(worker, "on_%s" % type, None)
+            if handler:
+                handler(**fields)
 
     def task_event(self, type, fields):
         """Process task event."""
@@ -241,7 +242,7 @@ class State(object):
         hostname = fields.pop("hostname")
         worker = self.get_or_create_worker(hostname)
         task = self.get_or_create_task(uuid)
-        handler = getattr(task, "on_%s" % type)
+        handler = getattr(task, "on_%s" % type, None)
         if type == "received":
             self.task_count += 1
         if handler:

+ 152 - 3
celery/tests/test_events_state.py

@@ -1,11 +1,11 @@
-import time
+from time import time
 import unittest2 as unittest
 
 from itertools import count
 
 from celery import states
 from celery.events import Event
-from celery.events.state import State, HEARTBEAT_EXPIRE
+from celery.events.state import State, Worker, Task, HEARTBEAT_EXPIRE
 from celery.utils import gen_unique_id
 
 
@@ -43,7 +43,7 @@ class ev_worker_online_offline(replay):
 class ev_worker_heartbeats(replay):
     events = [
         Event("worker-heartbeat", hostname="utest1",
-              timestamp=time.time() - HEARTBEAT_EXPIRE * 2),
+              timestamp=time() - HEARTBEAT_EXPIRE * 2),
         Event("worker-heartbeat", hostname="utest1"),
     ]
 
@@ -78,8 +78,53 @@ class ev_snapshot(replay):
                       uuid=gen_unique_id(), hostname=worker))
 
 
+class test_Worker(unittest.TestCase):
+
+    def test_survives_missing_timestamp(self):
+        worker = Worker(hostname="foo")
+        worker.on_heartbeat(timestamp=None)
+        self.assertEqual(worker.heartbeats, [])
+
+
+class test_Task(unittest.TestCase):
+
+    def test_info(self):
+        task = Task(uuid="abcdefg",
+                    name="tasks.add",
+                    args="(2, 2)",
+                    kwargs="{}",
+                    retries=2,
+                    result=42,
+                    eta=1,
+                    runtime=0.0001,
+                    expires=1,
+                    exception=1,
+                    received=time() - 10,
+                    started=time() - 8,
+                    succeeded=time())
+        self.assertItemsEqual(list(task._info_fields),
+                              task.info().keys())
+
+        self.assertItemsEqual(list(task._info_fields + ("received", )),
+                              task.info(extra=("received", )))
+
+        self.assertItemsEqual(["args", "kwargs"],
+                              task.info(["args", "kwargs"]).keys())
+
+    def test_ready(self):
+        task = Task(uuid="abcdefg",
+                    name="tasks.add")
+        task.on_received(timestamp=time())
+        self.assertFalse(task.ready)
+        task.on_succeeded(timestamp=time())
+        self.assertTrue(task.ready)
+
+
 class test_State(unittest.TestCase):
 
+    def test_repr(self):
+        self.assertTrue(repr(State()))
+
     def test_worker_online_offline(self):
         r = ev_worker_online_offline(State())
         r.next()
@@ -153,6 +198,87 @@ class test_State(unittest.TestCase):
         self.assertEqual(task.timestamp, task.revoked)
         self.assertEqual(task.worker.hostname, "utest1")
 
+    def test_freeze_thaw__buffering(self):
+        s = State()
+        r = ev_snapshot(s)
+        s.freeze(buffer=True)
+        self.assertTrue(s._buffering)
+
+        r.play()
+        self.assertStateEmpty(s)
+        self.assertTrue(s.buffer)
+
+        s.thaw()
+        self.assertState(s)
+        self.assertFalse(s.buffer)
+
+    def assertStateEmpty(self, state):
+        self.assertFalse(state.tasks)
+        self.assertFalse(state.workers)
+        self.assertFalse(state.event_count)
+        self.assertFalse(state.task_count)
+
+    def assertState(self, state):
+        self.assertTrue(state.tasks)
+        self.assertTrue(state.workers)
+        self.assertTrue(state.event_count)
+        self.assertTrue(state.task_count)
+
+    def test_thaw__no_replay(self):
+        s = State()
+        r = ev_snapshot(s)
+        s.freeze(buffer=True)
+
+        r.play()
+        s.thaw(replay=False)
+        self.assertFalse(s.buffer)
+        self.assertStateEmpty(s)
+
+    def test_freeze_while(self):
+        s = State()
+        r = ev_snapshot(s)
+
+        def work():
+            r.play()
+            self.assertStateEmpty(s)
+
+        s.freeze_while(work)
+        self.assertState(s)
+
+
+    def test_freeze_thaw__not_buffering(self):
+        s = State()
+        r = ev_snapshot(s)
+        s.freeze(buffer=False)
+        self.assertFalse(s._buffering)
+
+        r.play()
+        s.thaw(replay=True)
+        self.assertFalse(s.buffer)
+        self.assertStateEmpty(s)
+
+    def test_clear(self):
+        r = ev_snapshot(State())
+        r.play()
+        self.assertTrue(r.state.event_count)
+        self.assertTrue(r.state.workers)
+        self.assertTrue(r.state.tasks)
+        self.assertTrue(r.state.task_count)
+
+        r.state.clear()
+        self.assertFalse(r.state.event_count)
+        self.assertFalse(r.state.workers)
+        self.assertTrue(r.state.tasks)
+        self.assertFalse(r.state.task_count)
+
+        r.state.clear(False)
+        self.assertFalse(r.state.tasks)
+
+    def test_task_types(self):
+        r = ev_snapshot(State())
+        r.play()
+        self.assertItemsEqual(r.state.task_types(), ["task1", "task2"])
+
     def test_tasks_by_timestamp(self):
         r = ev_snapshot(State())
         r.play()
@@ -174,3 +300,26 @@ class test_State(unittest.TestCase):
         r.play()
         self.assertEqual(len(r.state.tasks_by_worker("utest1")), 10)
         self.assertEqual(len(r.state.tasks_by_worker("utest2")), 10)
+
+    def test_survives_unknown_worker_event(self):
+        s = State()
+        s.worker_event("worker-unknown-event-xxx", {"foo": "bar"})
+        s.worker_event("worker-unknown-event-xxx", {"hostname": "xxx",
+                                                    "foo": "bar"})
+
+    def test_survives_unknown_task_event(self):
+        s = State()
+        s.task_event("task-unknown-event-xxx", {"foo": "bar",
+                                                "uuid": "x",
+                                                "hostname": "y"})
+
+    def test_callback(self):
+        scratch = {}
+
+        def callback(state, event):
+            scratch["recv"] = True
+
+        s = State(callback=callback)
+        s.event({"type": "worker-online"})
+        self.assertTrue(scratch.get("recv"))
+