Browse Source

97% Coverage for celery.result

Ask Solem 14 years ago
parent
commit
9f6dcc7ad1
3 changed files with 91 additions and 15 deletions
  1. 14 13
      celery/result.py
  2. 1 1
      celery/task/base.py
  3. 76 1
      celery/tests/test_result.py

+ 14 - 13
celery/result.py

@@ -412,15 +412,15 @@ class EagerResult(BaseAsyncResult):
     """Result that we know has already been executed."""
     TimeoutError = TimeoutError
 
-    def __init__(self, task_id, ret_value, status, traceback=None):
+    def __init__(self, task_id, ret_value, state, traceback=None):
         self.task_id = task_id
         self._result = ret_value
-        self._status = status
+        self._state = state
         self._traceback = traceback
 
     def successful(self):
         """Returns :const:`True` if the task executed without failure."""
-        return self.status == states.SUCCESS
+        return self.state == states.SUCCESS
 
     def ready(self):
         """Returns :const:`True` if the task has been executed."""
@@ -428,24 +428,22 @@ class EagerResult(BaseAsyncResult):
 
     def wait(self, timeout=None):
         """Wait until the task has been executed and return its result."""
-        if self.status == states.SUCCESS:
+        if self.state == states.SUCCESS:
             return self.result
-        elif self.status in states.PROPAGATE_STATES:
+        elif self.state in states.PROPAGATE_STATES:
             raise self.result
 
     def revoke(self):
-        self._status = states.REVOKED
+        self._state = states.REVOKED
+
+    def __repr__(self):
+        return "<EagerResult: %s>" % self.task_id
 
     @property
     def result(self):
         """The tasks return value"""
         return self._result
 
-    @property
-    def status(self):
-        """The tasks status (alias to :attr:`state`)."""
-        return self._status
-
     @property
     def state(self):
         """The tasks state."""
@@ -456,5 +454,8 @@ class EagerResult(BaseAsyncResult):
         """The traceback if the task failed."""
         return self._traceback
 
-    def __repr__(self):
-        return "<EagerResult: %s>" % self.task_id
+    @property
+    def status(self):
+        """The tasks status (alias to :attr:`state`)."""
+        return self._state
+

+ 1 - 1
celery/task/base.py

@@ -289,7 +289,7 @@ class BaseTask(object):
         return self.app.log.setup_task_logger(loglevel=loglevel,
                                               logfile=logfile,
                                               propagate=propagate,
-                                              task_kwargs=self.request.kwargs)
+                            task_kwargs=self.request.get("kwargs"))
 
     @classmethod
     def establish_connection(self, connect_timeout=None):

+ 76 - 1
celery/tests/test_result.py

@@ -6,7 +6,8 @@ from celery import states
 from celery.app import app_or_default
 from celery.utils import gen_unique_id
 from celery.utils.compat import all
-from celery.result import AsyncResult, TaskSetResult
+from celery.utils.serialization import pickle
+from celery.result import AsyncResult, EagerResult, TaskSetResult
 from celery.exceptions import TimeoutError
 from celery.task.base import Task
 
@@ -47,6 +48,15 @@ class TestAsyncResult(unittest.TestCase):
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(task)
 
+    def test_reduce(self):
+        a1 = AsyncResult("uuid", task_name="celery.ping")
+        restored = pickle.loads(pickle.dumps(a1))
+        self.assertEqual(restored.task_id, "uuid")
+        self.assertEqual(restored.task_name, "celery.ping")
+
+        a2 = AsyncResult("uuid")
+        self.assertEqual(pickle.loads(pickle.dumps(a2)).task_id, "uuid")
+
     def test_successful(self):
         ok_res = AsyncResult(self.task1["id"])
         nok_res = AsyncResult(self.task3["id"])
@@ -108,6 +118,7 @@ class TestAsyncResult(unittest.TestCase):
         self.assertEqual(ok2_res.get(), "quick")
         self.assertRaises(KeyError, nok_res.get)
         self.assertIsInstance(nok2_res.result, KeyError)
+        self.assertEqual(ok_res.info, "the")
 
     def test_get_timeout(self):
         res = AsyncResult(self.task4["id"])             # has RETRY status
@@ -143,6 +154,10 @@ class MockAsyncResultFailure(AsyncResult):
 
 
 class MockAsyncResultSuccess(AsyncResult):
+    forgotten = False
+
+    def forget(self):
+        self.forgotten = True
 
     @property
     def result(self):
@@ -153,6 +168,16 @@ class MockAsyncResultSuccess(AsyncResult):
         return states.SUCCESS
 
 
+class SimpleBackend(object):
+        ids = []
+
+        def __init__(self, ids=[]):
+            self.ids = ids
+
+        def get_many(self, *args, **kwargs):
+            return ((id, {"result": i}) for i, id in enumerate(self.ids))
+
+
 class TestTaskSetResult(unittest.TestCase):
 
     def setUp(self):
@@ -168,6 +193,49 @@ class TestTaskSetResult(unittest.TestCase):
         it = iter(ts)
         self.assertRaises(KeyError, it.next)
 
+    def test_forget(self):
+        subs = [MockAsyncResultSuccess(gen_unique_id()),
+                MockAsyncResultSuccess(gen_unique_id())]
+        ts = TaskSetResult(gen_unique_id(), subs)
+        ts.forget()
+        for sub in subs:
+            self.assertTrue(sub.forgotten)
+
+    def test_getitem(self):
+        subs = [MockAsyncResultSuccess(gen_unique_id()),
+                MockAsyncResultSuccess(gen_unique_id())]
+        ts = TaskSetResult(gen_unique_id(), subs)
+        self.assertIs(ts[0], subs[0])
+
+    def test_save_restore(self):
+        subs = [MockAsyncResultSuccess(gen_unique_id()),
+                MockAsyncResultSuccess(gen_unique_id())]
+        ts = TaskSetResult(gen_unique_id(), subs)
+        ts.save()
+        self.assertRaises(AttributeError, ts.save, backend=object())
+        self.assertEqual(TaskSetResult.restore(ts.taskset_id).subtasks,
+                         ts.subtasks)
+        self.assertRaises(AttributeError,
+                          TaskSetResult.restore, ts.taskset_id,
+                          backend=object())
+
+    def test_join_native(self):
+        backend = SimpleBackend()
+        subtasks = [AsyncResult(gen_unique_id(), backend=backend)
+                        for i in range(10)]
+        ts = TaskSetResult(gen_unique_id(), subtasks)
+        backend.ids = [subtask.task_id for subtask in subtasks]
+        res = ts.join_native()
+        self.assertEqual(res, range(10))
+
+    def test_iter_native(self):
+        backend = SimpleBackend()
+        subtasks = [AsyncResult(gen_unique_id(), backend=backend)
+                        for i in range(10)]
+        ts = TaskSetResult(gen_unique_id(), subtasks)
+        backend.ids = [subtask.task_id for subtask in subtasks]
+        self.assertEqual(len(list(ts.iter_native())), 10)
+
     def test_iterate_yields(self):
         ar = MockAsyncResultSuccess(gen_unique_id())
         ar2 = MockAsyncResultSuccess(gen_unique_id())
@@ -302,6 +370,13 @@ class TestEagerResult(unittest.TestCase):
         res = RaisingTask.apply(args=[3, 3])
         self.assertRaises(KeyError, res.wait)
 
+    def test_wait(self):
+        res = EagerResult("x", "x", states.RETRY)
+        res.wait()
+        self.assertEqual(res.state, states.RETRY)
+        self.assertEqual(res.status, states.RETRY)
+
     def test_revoke(self):
         res = RaisingTask.apply(args=[3, 3])
         self.assertFalse(res.revoke())
+