浏览代码

Implemented AsyncResult.forget for sqla/cache/redis/tyrant backends. (Forget and remove task result). Closes #184. Thanks to dottedmag

Ask Solem 14 年之前
父节点
当前提交
c1a6670656

+ 10 - 0
celery/backends/base.py

@@ -67,6 +67,10 @@ class BaseBackend(object):
         """Prepare value for storage."""
         return result
 
+    def forget(self, task_id):
+        raise NotImplementedError("%s does not implement forget." % (
+                    self.__class__))
+
     def wait_for(self, task_id, timeout=None):
         """Wait for task and return its result.
 
@@ -211,6 +215,9 @@ class KeyValueStoreBackend(BaseDictBackend):
     def set(self, key, value):
         raise NotImplementedError("Must implement the set method.")
 
+    def delete(self, key):
+        raise NotImplementedError("Must implement the delete method")
+
     def get_key_for_task(self, task_id):
         """Get the cache key for a task by id."""
         return "celery-task-meta-%s" % task_id
@@ -219,6 +226,9 @@ class KeyValueStoreBackend(BaseDictBackend):
         """Get the cache key for a task by id."""
         return "celery-taskset-meta-%s" % task_id
 
+    def forget(self, task_id):
+        self.delete(task_id)
+
     def _store_result(self, task_id, result, status, traceback=None):
         meta = {"status": status, "result": result, "traceback": traceback}
         self.set(self.get_key_for_task(task_id), pickle.dumps(meta))

+ 6 - 0
celery/backends/cache.py

@@ -37,6 +37,9 @@ class DummyClient(object):
     def set(self, key, value, *args, **kwargs):
         self.cache[key] = value
 
+    def delete(self, key, *args, **kwargs):
+        self.cache.pop(key, None)
+
 
 backends = {"memcache": get_best_memcache,
             "memcached": get_best_memcache,
@@ -73,6 +76,9 @@ class CacheBackend(KeyValueStoreBackend):
     def set(self, key, value):
         return self.client.set(key, value, self.expires)
 
+    def delete(self, key):
+        return self.client.delete(key)
+
     @property
     def client(self):
         if self._client is None:

+ 10 - 0
celery/backends/database.py

@@ -86,6 +86,16 @@ class DatabaseBackend(BaseDictBackend):
         finally:
             session.close()
 
+    def forget(self, task_id):
+        """Forget about result."""
+        session = self.ResultSession()
+        expires = self.result_expires
+        try:
+            session.query(Task).filter(Task.task_id == task_id).delete()
+            session.commit()
+        finally:
+            session.close()
+
     def cleanup(self):
         """Delete expired metadata."""
         session = self.ResultSession()

+ 3 - 0
celery/backends/pyredis.py

@@ -106,3 +106,6 @@ class RedisBackend(KeyValueStoreBackend):
 
     def set(self, key, value):
         self.open().set(key, value)
+
+    def delete(self, key):
+        self.open().delete(key)

+ 3 - 0
celery/backends/tyrant.py

@@ -79,3 +79,6 @@ class TyrantBackend(KeyValueStoreBackend):
 
     def set(self, key, value):
         self.open()[key] = value
+
+    def delete(self, key):
+        self.open().pop(key, None)

+ 10 - 0
celery/result.py

@@ -35,6 +35,10 @@ class BaseAsyncResult(object):
         self.backend = backend
         self.app = app_or_default(app)
 
+    def forget(self):
+        """Forget about (and possibly remove the result of) this task."""
+        self.backend.forget(self.task_id)
+
     def revoke(self, connection=None, connect_timeout=None):
         """Send revoke signal to all workers.
 
@@ -252,6 +256,12 @@ class TaskSetResult(object):
         return sum(imap(int, (subtask.successful()
                                 for subtask in self.itersubtasks())))
 
+    def forget(self):
+        """Forget about (and possible remove the result of) all the tasks
+        in this taskset."""
+        for subtask in self.subtasks:
+            subtask.forget()
+
     def revoke(self, connection=None, connect_timeout=None):
 
         def _do_revoke(connection=None, connect_timeout=None):

+ 9 - 0
celery/tests/test_backends/test_cache.py

@@ -5,6 +5,7 @@ import unittest2 as unittest
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
+from celery.result import AsyncResult
 from celery.utils import gen_unique_id
 
 from celery.tests.utils import mask_modules
@@ -53,6 +54,14 @@ class test_CacheBackend(unittest.TestCase):
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
         self.assertIsInstance(tb.get_result(tid3), KeyError)
 
+    def test_forget(self):
+        tb = CacheBackend(backend="memory://")
+        tid = gen_unique_id()
+        tb.mark_as_done(tid, {"foo": "bar"})
+        x = AsyncResult(tid)
+        x.forget()
+        self.assertIsNone(x.result)
+
     def test_process_cleanup(self):
         tb = CacheBackend(backend="memory://")
         tb.process_cleanup()

+ 10 - 1
celery/tests/test_backends/test_database.py

@@ -6,9 +6,10 @@ from celery.exceptions import ImproperlyConfigured
 
 from celery import states
 from celery.app import app_or_default
+from celery.backends.database import DatabaseBackend
 from celery.db.models import Task, TaskSet
+from celery.result import AsyncResult
 from celery.utils import gen_unique_id
-from celery.backends.database import DatabaseBackend
 
 
 class SomeClass(object):
@@ -93,6 +94,14 @@ class test_DatabaseBackend(unittest.TestCase):
         self.assertIsInstance(tb.get_result(tid3), KeyError)
         self.assertEqual(tb.get_traceback(tid3), trace)
 
+    def test_forget(self):
+        tb = DatabaseBackend(backend="memory://")
+        tid = gen_unique_id()
+        tb.mark_as_done(tid, {"foo": "bar"})
+        x = AsyncResult(tid)
+        x.forget()
+        self.assertIsNone(x.result)
+
     def test_process_cleanup(self):
         tb = DatabaseBackend()
         tb.process_cleanup()