Browse Source

More functional tests

Ask Solem 14 years ago
parent
commit
bc6b598312
2 changed files with 95 additions and 30 deletions
  1. 68 22
      celery/tests/functional/case.py
  2. 27 8
      celery/tests/functional/test_basic.py

+ 68 - 22
celery/tests/functional/case.py

@@ -9,19 +9,25 @@ import unittest2 as unittest
 
 from itertools import count
 
-from celery.task.control import broadcast, ping
+from celery.exceptions import TimeoutError
+from celery.task.control import broadcast, ping, flatten_reply, inspect
 from celery.utils import get_full_cls_name
 
 HOSTNAME = socket.gethostname()
 
+
 def say(msg):
     sys.stderr.write("%s\n" % msg)
 
-def flatten_response(response):
-    flat = {}
-    for item in response:
-        flat.update(item)
-    return flat
+
+def try_while(fun, reason="Timed out", timeout=10, interval=0.5):
+    for iterations in count(0):
+        if iterations * interval >= timeout:
+            raise TimeoutError()
+        ret = fun()
+        if ret:
+            return ret
+
 
 class Worker(object):
     started = False
@@ -50,26 +56,19 @@ class Worker(object):
     def is_alive(self, timeout=1):
         r = ping(destination=[self.hostname],
                  timeout=timeout)
-        return self.hostname in flatten_response(r)
+        return self.hostname in flatten_reply(r)
 
     def wait_until_started(self, timeout=10, interval=0.2):
-        for iteration in count(0):
-            if iteration * interval >= timeout:
-                raise Exception(
-                        "Worker won't start (after %s secs.)" % timeout)
-            if self.is_alive(interval):
-                break
+        try_while(lambda: self.is_alive(interval),
+                "Worker won't start (after %s secs.)" % timeout,
+                interval=0.2, timeout=10)
         say("--WORKER %s IS ONLINE--" % self.hostname)
 
     def ensure_shutdown(self, timeout=10, interval=0.5):
         os.kill(self.pid, signal.SIGTERM)
-        for iteration in count(0):
-            if iteration * interval >= timeout:
-                raise Exception(
-                        "Worker won't shutdown (after %s secs.)" % timeout)
-            broadcast("shutdown", destination=[self.hostname])
-            if not self.is_alive(interval):
-                break
+        try_while(lambda: not self.is_alive(interval),
+                  "Worker won't shutdown (after %s secs.)" % timeout,
+                  timeout=10, interval=0.5)
         say("--WORKER %s IS SHUTDOWN--" % self.hostname)
         self._shutdown_called = True
 
@@ -115,6 +114,53 @@ class WorkerCase(unittest.TestCase):
     def assertWorkerAlive(self, timeout=1):
         self.assertTrue(self.worker.is_alive)
 
-    def my_response(self, response):
-        return flatten_response(response)[self.worker.hostname]
+    def inspect(self, timeout=1):
+        return inspect(self.worker.hostname, timeout=timeout)
 
+    def my_response(self, response):
+        return flatten_reply(response)[self.worker.hostname]
+
+    def is_accepted(self, task_id, interval=0.5):
+        active = self.inspect(timeout=interval).active()
+        if active:
+            for task in active:
+                if task["id"] == task_id:
+                    return True
+        return False
+
+    def is_reserved(self, task_id, interval=0.5):
+        reserved = self.inspect(timeout=interval).reserved()
+        if reserved:
+            for task in reserved:
+                if task["id"] == task_id:
+                    return True
+        return False
+
+    def is_scheduled(self, task_id, interval=0.5):
+        schedule = self.inspect(timeout=interval).scheduled()
+        if schedule:
+            for item in schedule:
+                if item["request"]["id"] == task_id:
+                    return True
+        return False
+
+    def is_received(self, task_id, interval=0.5):
+        return (self.is_reserved(task_id, interval) or
+                self.is_scheduled(task_id, interval) or
+                self.is_accepted(task_id, interval))
+
+
+    def ensure_accepted(self, task_id, interval=0.5, timeout=10):
+        return try_while(lambda: self.is_accepted(task_id, interval),
+                         "Task not accepted within timeout",
+                         interval=0.5, timeout=10)
+
+    def ensure_received(self, task_id, interval=0.5, timeout=10):
+        return try_while(lambda: self.is_received(task_id, interval),
+                        "Task not receied within timeout",
+                        interval=0.5, timeout=10)
+
+    def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
+        return try_while(lambda: self.is_scheduled(task_id, interval),
+                        "Task not scheduled within timeout",
+                        interval=0.5, timeout=10)

+ 27 - 8
celery/tests/functional/test_basic.py

@@ -19,14 +19,33 @@ class test_basic(WorkerCase):
         for result, i in results:
             self.assertEqual(result.get(timeout=10), operator.add(*i))
 
-    def test_dump_active(self):
-        tasks.sleeptask.delay(2)
-        tasks.sleeptask.delay(2)
-        time.sleep(0.2)
-        r = broadcast("dump_active",
-                           arguments={"safe": True}, reply=True)
-        active = self.my_response(r)
+    def test_dump_active(self, sleep=1):
+        r1 = tasks.sleeptask.delay(sleep)
+        r2 = tasks.sleeptask.delay(sleep)
+        self.ensure_accepted(r1.task_id)
+        active = self.inspect().active(safe=True)
         self.assertEqual(len(active), 2)
         self.assertEqual(active[0]["name"], tasks.sleeptask.name)
-        self.assertEqual(active[0]["args"], [2])
+        self.assertEqual(active[0]["args"], [sleep])
+
+    def test_dump_reserved(self, sleep=1):
+        r1 = tasks.sleeptask.delay(sleep)
+        r2 = tasks.sleeptask.delay(sleep)
+        r3 = tasks.sleeptask.delay(sleep)
+        r4 = tasks.sleeptask.delay(sleep)
+        self.ensure_accepted(r1.task_id)
+        reserved = self.inspect().reserved(safe=True)
+        self.assertTrue(reserved)
+        self.assertEqual(reserved[0]["name"], tasks.sleeptask.name)
+        self.assertEqual(reserved[0]["args"], [sleep])
+
+    def test_dump_schedule(self, countdown=1):
+        r1 = tasks.add.apply_async((2, 2), countdown=countdown)
+        r2 = tasks.add.apply_async((2, 2), countdown=countdown)
+        self.ensure_scheduled(r1.task_id, interval=0.1)
+        schedule = self.inspect().scheduled(safe=True)
+        self.assertTrue(schedule)
+        self.assertTrue(len(schedule), 2)
+        self.assertEqual(schedule[0]["request"]["name"], tasks.add.name)
+        self.assertEqual(schedule[0]["request"]["args"], [2, 2])