Browse Source

Unittests testing celery.pool.TaskPool

Ask Solem 16 years ago
parent
commit
fd55e30655
2 changed files with 114 additions and 0 deletions
  1. 2 0
      celery/pool.py
  2. 112 0
      celery/tests/test_pool.py

+ 2 - 0
celery/pool.py

@@ -52,6 +52,8 @@ class TaskPool(object):
     def stop(self):
         """Terminate the pool."""
         self._pool.terminate()
+        self._processes = {}
+        self._pool = None
 
     def apply_async(self, target, args=None, kwargs=None, callbacks=None,
             errbacks=None, on_acknowledge=None, meta=None):

+ 112 - 0
celery/tests/test_pool.py

@@ -0,0 +1,112 @@
+import unittest
+import logging
+import itertools
+import time
+from celery.pool import TaskPool
+from celery.datastructures import ExceptionInfo
+import sys
+
+
+def do_something(i):
+    return i * i
+
+
+def long_something():
+    import time
+    time.sleep(1)
+
+
+def raise_something(i):
+    try:
+        raise KeyError("FOO EXCEPTION")
+    except KeyError:
+        return ExceptionInfo(sys.exc_info())
+
+
+class TestTaskPool(unittest.TestCase):
+
+    def test_attrs(self):
+        p = TaskPool(limit=2)
+        self.assertEquals(p.limit, 2)
+        self.assertTrue(isinstance(p.logger, logging.Logger))
+        self.assertEquals(p._processed_total, 0)
+        self.assertEquals(p._process_counter.next(), 1)
+        self.assertTrue(p._pool is None)
+        self.assertTrue(p._processes is None)
+
+    def x_start_stop(self):
+        p = TaskPool(limit=2)
+        p.start()
+        self.assertTrue(p._pool)
+        self.assertTrue(isinstance(p._processes, dict))
+        p.stop()
+        self.assertTrue(p._pool is None)
+        self.assertFalse(p._processes)
+
+    def test_apply(self):
+        p = TaskPool(limit=2)
+        p.start()
+        scratchpad = {}
+        proc_counter = itertools.count().next
+
+        def mycallback(ret_value, meta):
+            process = proc_counter()
+            scratchpad[process] = {}
+            scratchpad[process]["ret_value"] = ret_value
+            scratchpad[process]["meta"] = meta
+
+        myerrback = mycallback
+
+        res = p.apply_async(do_something, args=[10], callbacks=[mycallback],
+                            meta={"foo": "bar"})
+        res2 = p.apply_async(raise_something, args=[10], errbacks=[myerrback],
+                            meta={"foo2": "bar2"})
+        res3 = p.apply_async(do_something, args=[20], callbacks=[mycallback],
+                            meta={"foo3": "bar3"})
+
+        self.assertEquals(res.get(), 100)
+        time.sleep(0.5)
+        self.assertTrue(scratchpad.get(0))
+        self.assertEquals(scratchpad[0]["ret_value"], 100)
+        self.assertEquals(scratchpad[0]["meta"], {"foo": "bar"})
+
+        self.assertTrue(isinstance(res2.get(), ExceptionInfo))
+        self.assertTrue(scratchpad.get(1))
+        time.sleep(0.5)
+        #self.assertEquals(scratchpad[1]["ret_value"], "FOO")
+        self.assertTrue(isinstance(scratchpad[1]["ret_value"],
+                          ExceptionInfo))
+        self.assertEquals(scratchpad[1]["ret_value"].exception.args,
+                          ("FOO EXCEPTION", ))
+        self.assertEquals(scratchpad[1]["meta"], {"foo2": "bar2"})
+
+        self.assertEquals(res3.get(), 400)
+        time.sleep(0.5)
+        self.assertTrue(scratchpad.get(2))
+        self.assertEquals(scratchpad[2]["ret_value"], 400)
+        self.assertEquals(scratchpad[2]["meta"], {"foo3": "bar3"})
+        
+        res3 = p.apply_async(do_something, args=[30], callbacks=[mycallback],
+                            meta={"foo4": "bar4"})
+
+        self.assertEquals(res3.get(), 900)
+        time.sleep(0.5)
+        self.assertTrue(scratchpad.get(3))
+        self.assertEquals(scratchpad[3]["ret_value"], 900)
+        self.assertEquals(scratchpad[3]["meta"], {"foo4": "bar4"})
+
+        p.stop()
+
+    def test_is_full(self):
+        p = TaskPool(2)
+        p.start()
+        self.assertFalse(p.full())
+        results = [p.apply_async(long_something) for i in xrange(4)]
+        self.assertTrue(p.full())
+        p.stop()
+
+    def test_get_worker_pids(self):
+        p = TaskPool(5)
+        p.start()
+        self.assertEquals(len(p.get_worker_pids()), 5)
+        p.stop()