浏览代码

99% coverage for celery.events.state

Ask Solem 14 年之前
父节点
当前提交
2c41cd2e37
共有 2 个文件被更改,包括 159 次插入9 次删除
  1. 7 6
      celery/events/state.py
  2. 152 3
      celery/tests/test_events_state.py

+ 7 - 6
celery/events/state.py

@@ -229,11 +229,12 @@ class State(object):
 
 
     def worker_event(self, type, fields):
     def worker_event(self, type, fields):
         """Process worker event."""
         """Process worker event."""
-        hostname = fields.pop("hostname")
-        worker = self.get_or_create_worker(hostname)
-        handler = getattr(worker, "on_%s" % type)
-        if handler:
-            handler(**fields)
+        hostname = fields.pop("hostname", None)
+        if hostname:
+            worker = self.get_or_create_worker(hostname)
+            handler = getattr(worker, "on_%s" % type, None)
+            if handler:
+                handler(**fields)
 
 
     def task_event(self, type, fields):
     def task_event(self, type, fields):
         """Process task event."""
         """Process task event."""
@@ -241,7 +242,7 @@ class State(object):
         hostname = fields.pop("hostname")
         hostname = fields.pop("hostname")
         worker = self.get_or_create_worker(hostname)
         worker = self.get_or_create_worker(hostname)
         task = self.get_or_create_task(uuid)
         task = self.get_or_create_task(uuid)
-        handler = getattr(task, "on_%s" % type)
+        handler = getattr(task, "on_%s" % type, None)
         if type == "received":
         if type == "received":
             self.task_count += 1
             self.task_count += 1
         if handler:
         if handler:

+ 152 - 3
celery/tests/test_events_state.py

@@ -1,11 +1,11 @@
-import time
+from time import time
 import unittest2 as unittest
 import unittest2 as unittest
 
 
 from itertools import count
 from itertools import count
 
 
 from celery import states
 from celery import states
 from celery.events import Event
 from celery.events import Event
-from celery.events.state import State, HEARTBEAT_EXPIRE
+from celery.events.state import State, Worker, Task, HEARTBEAT_EXPIRE
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
 
 
 
 
@@ -43,7 +43,7 @@ class ev_worker_online_offline(replay):
 class ev_worker_heartbeats(replay):
 class ev_worker_heartbeats(replay):
     events = [
     events = [
         Event("worker-heartbeat", hostname="utest1",
         Event("worker-heartbeat", hostname="utest1",
-              timestamp=time.time() - HEARTBEAT_EXPIRE * 2),
+              timestamp=time() - HEARTBEAT_EXPIRE * 2),
         Event("worker-heartbeat", hostname="utest1"),
         Event("worker-heartbeat", hostname="utest1"),
     ]
     ]
 
 
@@ -78,8 +78,53 @@ class ev_snapshot(replay):
                       uuid=gen_unique_id(), hostname=worker))
                       uuid=gen_unique_id(), hostname=worker))
 
 
 
 
+class test_Worker(unittest.TestCase):
+
+    def test_survives_missing_timestamp(self):
+        worker = Worker(hostname="foo")
+        worker.on_heartbeat(timestamp=None)
+        self.assertEqual(worker.heartbeats, [])
+
+
+class test_Task(unittest.TestCase):
+
+    def test_info(self):
+        task = Task(uuid="abcdefg",
+                    name="tasks.add",
+                    args="(2, 2)",
+                    kwargs="{}",
+                    retries=2,
+                    result=42,
+                    eta=1,
+                    runtime=0.0001,
+                    expires=1,
+                    exception=1,
+                    received=time() - 10,
+                    started=time() - 8,
+                    succeeded=time())
+        self.assertItemsEqual(list(task._info_fields),
+                              task.info().keys())
+
+        self.assertItemsEqual(list(task._info_fields + ("received", )),
+                              task.info(extra=("received", )))
+
+        self.assertItemsEqual(["args", "kwargs"],
+                              task.info(["args", "kwargs"]).keys())
+
+    def test_ready(self):
+        task = Task(uuid="abcdefg",
+                    name="tasks.add")
+        task.on_received(timestamp=time())
+        self.assertFalse(task.ready)
+        task.on_succeeded(timestamp=time())
+        self.assertTrue(task.ready)
+
+
 class test_State(unittest.TestCase):
 class test_State(unittest.TestCase):
 
 
+    def test_repr(self):
+        self.assertTrue(repr(State()))
+
     def test_worker_online_offline(self):
     def test_worker_online_offline(self):
         r = ev_worker_online_offline(State())
         r = ev_worker_online_offline(State())
         r.next()
         r.next()
@@ -153,6 +198,87 @@ class test_State(unittest.TestCase):
         self.assertEqual(task.timestamp, task.revoked)
         self.assertEqual(task.timestamp, task.revoked)
         self.assertEqual(task.worker.hostname, "utest1")
         self.assertEqual(task.worker.hostname, "utest1")
 
 
+    def test_freeze_thaw__buffering(self):
+        s = State()
+        r = ev_snapshot(s)
+        s.freeze(buffer=True)
+        self.assertTrue(s._buffering)
+
+        r.play()
+        self.assertStateEmpty(s)
+        self.assertTrue(s.buffer)
+
+        s.thaw()
+        self.assertState(s)
+        self.assertFalse(s.buffer)
+
+    def assertStateEmpty(self, state):
+        self.assertFalse(state.tasks)
+        self.assertFalse(state.workers)
+        self.assertFalse(state.event_count)
+        self.assertFalse(state.task_count)
+
+    def assertState(self, state):
+        self.assertTrue(state.tasks)
+        self.assertTrue(state.workers)
+        self.assertTrue(state.event_count)
+        self.assertTrue(state.task_count)
+
+    def test_thaw__no_replay(self):
+        s = State()
+        r = ev_snapshot(s)
+        s.freeze(buffer=True)
+
+        r.play()
+        s.thaw(replay=False)
+        self.assertFalse(s.buffer)
+        self.assertStateEmpty(s)
+
+    def test_freeze_while(self):
+        s = State()
+        r = ev_snapshot(s)
+
+        def work():
+            r.play()
+            self.assertStateEmpty(s)
+
+        s.freeze_while(work)
+        self.assertState(s)
+
+
+    def test_freeze_thaw__not_buffering(self):
+        s = State()
+        r = ev_snapshot(s)
+        s.freeze(buffer=False)
+        self.assertFalse(s._buffering)
+
+        r.play()
+        s.thaw(replay=True)
+        self.assertFalse(s.buffer)
+        self.assertStateEmpty(s)
+
+    def test_clear(self):
+        r = ev_snapshot(State())
+        r.play()
+        self.assertTrue(r.state.event_count)
+        self.assertTrue(r.state.workers)
+        self.assertTrue(r.state.tasks)
+        self.assertTrue(r.state.task_count)
+
+        r.state.clear()
+        self.assertFalse(r.state.event_count)
+        self.assertFalse(r.state.workers)
+        self.assertTrue(r.state.tasks)
+        self.assertFalse(r.state.task_count)
+
+        r.state.clear(False)
+        self.assertFalse(r.state.tasks)
+
+    def test_task_types(self):
+        r = ev_snapshot(State())
+        r.play()
+        self.assertItemsEqual(r.state.task_types(), ["task1", "task2"])
+
     def test_tasks_by_timestamp(self):
     def test_tasks_by_timestamp(self):
         r = ev_snapshot(State())
         r = ev_snapshot(State())
         r.play()
         r.play()
@@ -174,3 +300,26 @@ class test_State(unittest.TestCase):
         r.play()
         r.play()
         self.assertEqual(len(r.state.tasks_by_worker("utest1")), 10)
         self.assertEqual(len(r.state.tasks_by_worker("utest1")), 10)
         self.assertEqual(len(r.state.tasks_by_worker("utest2")), 10)
         self.assertEqual(len(r.state.tasks_by_worker("utest2")), 10)
+
+    def test_survives_unknown_worker_event(self):
+        s = State()
+        s.worker_event("worker-unknown-event-xxx", {"foo": "bar"})
+        s.worker_event("worker-unknown-event-xxx", {"hostname": "xxx",
+                                                    "foo": "bar"})
+
+    def test_survives_unknown_task_event(self):
+        s = State()
+        s.task_event("task-unknown-event-xxx", {"foo": "bar",
+                                                "uuid": "x",
+                                                "hostname": "y"})
+
+    def test_callback(self):
+        scratch = {}
+
+        def callback(state, event):
+            scratch["recv"] = True
+
+        s = State(callback=callback)
+        s.event({"type": "worker-online"})
+        self.assertTrue(scratch.get("recv"))
+