Browse Source

Tests for KeyValueStore support for saving/restoring taskset results.

Ask Solem 15 years ago
parent
commit
773101fc44
3 changed files with 23 additions and 2 deletions
  1. 1 1
      celery/backends/base.py
  2. 10 1
      celery/result.py
  3. 12 0
      celery/tests/test_backends/test_cache.py

+ 1 - 1
celery/backends/base.py

@@ -190,7 +190,7 @@ class KeyValueStoreBackend(BaseDictBackend):
 
     def _store_taskset(self, taskset_id, result):
         meta = {"result": result}
-        self.set(self.get_key_for_taskset(task_id), pickle.dumps(meta))
+        self.set(self.get_key_for_taskset(taskset_id), pickle.dumps(meta))
         return result
 
     def _get_task_meta_for(self, task_id):

+ 10 - 1
celery/result.py

@@ -88,9 +88,17 @@ class BaseAsyncResult(object):
         """``str(self)`` -> ``self.task_id``"""
         return self.task_id
 
+    def __hash__(self):
+        return hash(self.task_id)
+
     def __repr__(self):
         return "<AsyncResult: %s>" % self.task_id
 
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return self.task_id == other.task_id
+        return self == other
+
     @property
     def result(self):
         """When the task has been executed, this contains the return value.
@@ -134,6 +142,7 @@ class BaseAsyncResult(object):
         return self.backend.get_status(self.task_id)
 
 
+
 class AsyncResult(BaseAsyncResult):
     """Pending task result using the default backend.
 
@@ -312,7 +321,7 @@ class TaskSetResult(object):
             >>> result = TaskSetResult.restore(task_id)
 
         """
-        backend.store_taskset(taskset_id, result)
+        backend.store_taskset(self.taskset_id, self)
 
     @classmethod
     def restore(self, taskset_id, backend=default_backend):

+ 12 - 0
celery/tests/test_backends/test_cache.py

@@ -3,6 +3,7 @@ import unittest
 
 from billiard.serialization import pickle
 
+from celery import result
 from celery import states
 from celery.utils import gen_unique_id
 from celery.backends.cache import CacheBackend
@@ -33,6 +34,17 @@ class TestCacheBackend(unittest.TestCase):
         self.assertTrue(cb._cache.get(tid))
         self.assertTrue(cb.get_result(tid), 42)
 
+    def test_save_restore_taskset(self):
+        backend = CacheBackend()
+        taskset_id = gen_unique_id()
+        subtask_ids = [gen_unique_id() for i in range(10)]
+        subtasks = map(result.AsyncResult, subtask_ids)
+        res = result.TaskSetResult(taskset_id, subtasks)
+        res.save(backend=backend)
+        saved = result.TaskSetResult.restore(taskset_id, backend=backend)
+        self.assertEquals(saved.subtasks, subtasks)
+        self.assertEquals(saved.taskset_id, taskset_id)
+
     def test_is_pickled(self):
         cb = CacheBackend()