Pārlūkot izejas kodu

Tests for celery.result (AsyncResult, TaskSetResult, EagerResult)

Ask Solem 16 gadi atpakaļ
vecāks
revīzija
073562e940
2 mainītis faili ar 195 papildinājumiem un 1 dzēšanām
  1. 159 0
      celery/tests/test_result.py
  2. 36 1
      celery/tests/test_task.py

+ 159 - 0
celery/tests/test_result.py

@@ -0,0 +1,159 @@
+import unittest
+from celery.backends import default_backend
+from celery.result import AsyncResult
+from celery.result import TaskSetResult
+from celery.utils import gen_unique_id
+
+
+def mock_task(name, status, result):
+    return dict(id=gen_unique_id(), name=name, status=status, result=result)
+
+
+def save_result(task):
+    if task["status"] == "DONE":
+        default_backend.mark_as_done(task["id"], task["result"])
+    else:
+        default_backend.mark_as_failure(task["id"], task["result"])
+
+
+def make_mock_taskset(size=10):
+    tasks = [mock_task("ts%d" % i, "DONE", i) for i in xrange(size)]
+    [save_result(task) for task in tasks]
+    return [AsyncResult(task["id"]) for task in tasks]
+
+
+class TestAsyncResult(unittest.TestCase):
+
+    def setUp(self):
+        self.task1 = mock_task("task1", "DONE", "the")
+        self.task2 = mock_task("task2", "DONE", "quick")
+        self.task3 = mock_task("task3", "FAILURE", KeyError("brown"))
+
+        for task in (self.task1, self.task2, self.task3):
+            save_result(task)
+
+    def test_is_done(self):
+        ok_res = AsyncResult(self.task1["id"])
+        nok_res = AsyncResult(self.task3["id"])
+
+        self.assertTrue(ok_res.is_done())
+        self.assertFalse(nok_res.is_done())
+    
+    def test_sucessful(self):
+        ok_res = AsyncResult(self.task1["id"])
+        nok_res = AsyncResult(self.task3["id"])
+
+        self.assertTrue(ok_res.successful())
+        self.assertFalse(nok_res.successful())
+       
+    def test_str(self):
+        ok_res = AsyncResult(self.task1["id"])
+        ok2_res = AsyncResult(self.task2["id"])
+        nok_res = AsyncResult(self.task3["id"])
+        self.assertEquals(str(ok_res), self.task1["id"])
+        self.assertEquals(str(ok2_res), self.task2["id"])
+        self.assertEquals(str(nok_res), self.task3["id"])
+    
+    def test_repr(self):
+        ok_res = AsyncResult(self.task1["id"])
+        ok2_res = AsyncResult(self.task2["id"])
+        nok_res = AsyncResult(self.task3["id"])
+        self.assertEquals(repr(ok_res), "<AsyncResult: %s>" % (
+                self.task1["id"]))
+        self.assertEquals(repr(ok2_res), "<AsyncResult: %s>" % (
+                self.task2["id"]))
+        self.assertEquals(repr(nok_res), "<AsyncResult: %s>" % (
+                self.task3["id"]))
+
+    def test_get(self):
+        ok_res = AsyncResult(self.task1["id"])
+        ok2_res = AsyncResult(self.task2["id"])
+        nok_res = AsyncResult(self.task3["id"])
+
+        self.assertEquals(ok_res.get(), "the")
+        self.assertEquals(ok2_res.get(), "quick")
+        self.assertRaises(KeyError, nok_res.get)
+
+    def test_ready(self):
+        oks = (AsyncResult(self.task1["id"]),
+               AsyncResult(self.task2["id"]),
+               AsyncResult(self.task3["id"]))
+        [self.assertTrue(ok.ready()) for ok in oks]
+
+
+class TestTaskSetResult(unittest.TestCase):
+
+    def setUp(self):
+        self.size = 10
+        self.ts = TaskSetResult(gen_unique_id(), make_mock_taskset(self.size))
+
+    def test_total(self):
+        self.assertEquals(self.ts.total, self.size)
+
+    def test_itersubtasks(self):
+
+        it = self.ts.itersubtasks()
+
+        for i, t in enumerate(it):
+            self.assertEquals(t.get(), i)
+    
+    def test___iter__(self):
+
+        it = iter(self.ts)
+
+        results = sorted(list(it))
+        self.assertEquals(results, list(xrange(self.size)))
+
+    def test_join(self):
+        joined = self.ts.join()
+        self.assertEquals(joined, list(xrange(self.size)))
+
+    def test_successful(self):
+        self.assertTrue(self.ts.successful())
+    
+    def test_failed(self):
+        self.assertFalse(self.ts.failed())
+    
+    def test_waiting(self):
+        self.assertFalse(self.ts.waiting())
+
+    def test_ready(self):
+        self.assertTrue(self.ts.ready())
+    
+    def test_completed_count(self):
+        self.assertEquals(self.ts.completed_count(), self.ts.total)
+
+
+class TestFailedTaskSetResult(TestTaskSetResult):
+
+    def setUp(self):
+        self.size = 11
+        subtasks = make_mock_taskset(10)
+        failed = mock_task("ts11", "FAILED", KeyError("Baz"))
+        save_result(failed)
+        failed_res = AsyncResult(failed["id"])
+        self.ts = TaskSetResult(gen_unique_id(), subtasks + [failed_res])
+    
+    def test_itersubtasks(self):
+
+        it = self.ts.itersubtasks()
+
+        for i in xrange(self.size - 1):
+            t = it.next()
+            self.assertEquals(t.get(), i)
+        self.assertRaises(KeyError, it.next().get)
+
+    def test_completed_count(self):
+        self.assertEquals(self.ts.completed_count(), self.ts.total - 1)
+
+    def test___iter__(self):
+        pass
+
+    def test_join(self):
+        self.assertRaises(KeyError, self.ts.join)
+    
+    def test_successful(self):
+        self.assertFalse(self.ts.successful())
+    
+    def test_failed(self):
+        self.assertTrue(self.ts.failed())

+ 36 - 1
celery/tests/test_task.py

@@ -7,6 +7,7 @@ from celery import task
 from celery import registry
 from celery.log import setup_logger
 from celery import messaging
+from celery.result import EagerResult
 from celery.backends import default_backend
 
 
@@ -23,9 +24,17 @@ class IncrementCounterTask(task.Task):
     name = "c.unittest.increment_counter_task"
     count = 0
 
-    def run(self, increment_by, **kwargs):
+    def run(self, increment_by=1, **kwargs):
         increment_by = increment_by or 1
         self.__class__.count += increment_by
+        return self.__class__.count
+
+
+class RaisingTask(task.Task):
+    name = "c.unittest.raising_task"
+
+    def run(self, **kwargs):
+        raise KeyError("foo")
 
 
 class TestCeleryTasks(unittest.TestCase):
@@ -107,6 +116,7 @@ class TestCeleryTasks(unittest.TestCase):
 class TestTaskSet(unittest.TestCase):
 
     def test_counter_taskset(self):
+        IncrementCounterTask.count = 0
         ts = task.TaskSet(IncrementCounterTask, [
             [[], {}],
             [[], {"increment_by": 2}],
@@ -134,3 +144,28 @@ class TestTaskSet(unittest.TestCase):
             IncrementCounterTask().run(
                     increment_by=m.get("kwargs", {}).get("increment_by"))
         self.assertEquals(IncrementCounterTask.count, sum(xrange(1, 10)))
+
+
+class TestTaskApply(unittest.TestCase):
+
+    def test_apply(self):
+        IncrementCounterTask.count = 0
+
+        e = IncrementCounterTask.apply()
+        self.assertTrue(isinstance(e, EagerResult))
+        self.assertEquals(e.get(), 1)
+        
+        e = IncrementCounterTask.apply(args=[1])
+        self.assertEquals(e.get(), 2)
+        
+        e = IncrementCounterTask.apply(kwargs={"increment_by": 4})
+        self.assertEquals(e.get(), 6)
+
+        self.assertTrue(e.is_done())
+        self.assertTrue(e.is_ready())
+        self.assertTrue(repr(e).startswith("<EagerResult:"))
+
+        f = RaisingTask.apply()
+        self.assertTrue(f.is_ready())
+        self.assertFalse(f.is_done())
+        self.assertRaises(KeyError, f.get)