Browse Source

Clean up the backends

Ask Solem 15 years ago
parent
commit
a462e42dcc

+ 58 - 21
celery/backends/base.py

@@ -20,6 +20,9 @@ class BaseBackend(object):
 
     capabilities = []
 
+    def __init__(self, *args, **kwargs):
+        pass
+
     def encode_result(self, result, status):
         if status == states.SUCCESS:
             return self.prepare_value(result)
@@ -118,41 +121,33 @@ class BaseBackend(object):
         raise NotImplementedError(
                 "store_taskset is not supported by this backend.")
 
-    def get_taskset(self, task_id):
+    def get_taskset(self, taskset_id):
         """Get the result of a taskset."""
         raise NotImplementedError(
                 "get_taskset is not supported by this backend.")
 
 
-class KeyValueStoreBackend(BaseBackend):
+class BaseDictBackend(BaseBackend):
 
     capabilities = ["ResultStore"]
 
     def __init__(self, *args, **kwargs):
-        super(KeyValueStoreBackend, self).__init__()
+        super(BaseDictBackend, self).__init__(*args, **kwargs)
         self._cache = {}
 
-    def get_cache_key_for_task(self, task_id):
-        """Get the cache key for a task by id."""
-        return "celery-task-meta-%s" % task_id
-
-    def get(self, key):
-        raise NotImplementedError("Must implement the get method.")
-
-    def set(self, key, value):
-        raise NotImplementedError("Must implement the set method.")
-
     def store_result(self, task_id, result, status, traceback=None):
         """Store task result and status."""
         result = self.encode_result(result, status)
-        meta = {"status": status, "result": result, "traceback": traceback}
-        self.set(self.get_cache_key_for_task(task_id), pickle.dumps(meta))
-        return result
+        return self._store_result(task_id, result, status, traceback)
 
     def get_status(self, task_id):
         """Get the status of a task."""
         return self._get_task_meta_for(task_id)["status"]
 
+    def get_traceback(self, task_id):
+        """Get the traceback for a failed task."""
+        return self._get_task_meta_for(task_id)["traceback"]
+
     def get_result(self, task_id):
         """Get the result of a task."""
         meta = self._get_task_meta_for(task_id)
@@ -161,19 +156,61 @@ class KeyValueStoreBackend(BaseBackend):
         else:
             return meta["result"]
 
-    def get_traceback(self, task_id):
-        """Get the traceback for a failed task."""
-        meta = self._get_task_meta_for(task_id)
-        return meta["traceback"]
+    def get_taskset(self, taskset_id):
+        """Get the result for a taskset."""
+        meta = self._get_taskset_meta_for(taskset_id)
+        if meta:
+            return meta["result"]
+
+    def store_taskset(self, taskset_id, result):
+        """Store the result of an executed taskset."""
+        return self._store_taskset(taskset_id, result)
+
+
+class KeyValueStoreBackend(BaseDictBackend):
+
+    def get(self, key):
+        raise NotImplementedError("Must implement the get method.")
+
+    def set(self, key, value):
+        raise NotImplementedError("Must implement the set method.")
+
+    def get_key_for_task(self, task_id):
+        """Get the cache key for a task by id."""
+        return "celery-task-meta-%s" % task_id
+
+    def get_key_for_taskset(self, task_id):
+        """Get the cache key for a task by id."""
+        return "celery-taskset-meta-%s" % task_id
+
+    def _store_result(self, task_id, result, status, traceback=None):
+        meta = {"status": status, "result": result, "traceback": traceback}
+        self.set(self.get_key_for_task(task_id), pickle.dumps(meta))
+        return result
+
+    def _store_taskset(self, taskset_id, result):
+        meta = {"result": result}
+        self.set(self.get_key_for_taskset(task_id), pickle.dumps(meta))
+        return result
 
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
         if task_id in self._cache:
             return self._cache[task_id]
-        meta = self.get(self.get_cache_key_for_task(task_id))
+        meta = self.get(self.get_key_for_task(task_id))
         if not meta:
             return {"status": states.PENDING, "result": None}
         meta = pickle.loads(str(meta))
         if meta.get("status") == states.SUCCESS:
             self._cache[task_id] = meta
         return meta
+
+    def _get_taskset_meta_for(self, taskset_id):
+        """Get task metadata for a task by id."""
+        if taskset_id in self._cache:
+            return self._cache[taskset_id]
+        meta = self.get(self.get_key_for_taskset(taskset_id))
+        if meta:
+            meta = pickle.loads(str(meta))
+            self._cache[taskset_id] = meta
+            return meta

+ 8 - 37
celery/backends/database.py

@@ -1,68 +1,39 @@
 from celery import states
 from celery.models import TaskMeta, TaskSetMeta
-from celery.backends.base import BaseBackend
+from celery.backends.base import BaseDictBackend
 
 
-class DatabaseBackend(BaseBackend):
+class DatabaseBackend(BaseDictBackend):
     """The database backends. Using Django models to store task metadata."""
 
-    capabilities = ["ResultStore"]
-
-    def __init__(self, *args, **kwargs):
-        super(DatabaseBackend, self).__init__(*args, **kwargs)
-        self._cache = {}
-
-    def store_result(self, task_id, result, status, traceback=None):
+    def _store_result(self, task_id, result, status, traceback=None):
         """Store return value and status of an executed task."""
-        result = self.encode_result(result, status)
         TaskMeta.objects.store_result(task_id, result, status,
                                       traceback=traceback)
         return result
 
-    def store_taskset(self, taskset_id, result):
+    def _store_taskset(self, taskset_id, result):
         """Store the result of an executed taskset."""
         TaskSetMeta.objects.store_result(taskset_id, result)
         return result
 
-    def get_status(self, task_id):
-        """Get the status of a task."""
-        return self._get_task_meta_for(task_id).status
-
-    def get_traceback(self, task_id):
-        """Get the traceback of a failed task."""
-        return self._get_task_meta_for(task_id).traceback
-
-    def get_result(self, task_id):
-        """Get the result for a task."""
-        meta = self._get_task_meta_for(task_id)
-        if meta.status in states.EXCEPTION_STATES:
-            return self.exception_to_python(meta.result)
-        else:
-            return meta.result
-
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
         if task_id in self._cache:
             return self._cache[task_id]
-        meta = TaskMeta.objects.get_task(task_id)
-        if meta.status == states.SUCCESS:
+        meta = TaskMeta.objects.get_task(task_id).to_dict()
+        if meta["status"] == states.SUCCESS:
             self._cache[task_id] = meta
         return meta
 
-    def get_taskset(self, taskset_id):
-        """Get the result for a taskset."""
-        meta = self._get_taskset_meta_for(taskset_id)
-        if meta:
-            return meta.result
-
     def _get_taskset_meta_for(self, taskset_id):
         """Get taskset metadata for a taskset by id."""
         if taskset_id in self._cache:
             return self._cache[taskset_id]
-        meta = TaskSetMeta.objects.get_taskset(taskset_id)
+        meta = TaskSetMeta.objects.get_taskset(taskset_id).to_dict()
         if meta:
             self._cache[taskset_id] = meta
-            return meta
+        return meta
 
     def cleanup(self):
         """Delete expired metadata."""

+ 3 - 23
celery/backends/mongodb.py

@@ -10,7 +10,7 @@ except ImportError:
 
 from celery import conf
 from celery import states
-from celery.backends.base import BaseBackend
+from celery.backends.base import BaseDictBackend
 from celery.loaders import load_settings
 
 
@@ -20,7 +20,7 @@ class Bunch:
         self.__dict__.update(kw)
 
 
-class MongoBackend(BaseBackend):
+class MongoBackend(BaseDictBackend):
 
     capabilities = ["ResultStore"]
 
@@ -97,12 +97,10 @@ class MongoBackend(BaseBackend):
             # goes out of scope
             self._connection = None
 
-    def store_result(self, task_id, result, status, traceback=None):
+    def _store_result(self, task_id, result, status, traceback=None):
         """Store return value and status of an executed task."""
         from pymongo.binary import Binary
 
-        result = self.encode_result(result, status)
-
         meta = {"_id": task_id,
                 "status": status,
                 "result": Binary(pickle.dumps(result)),
@@ -111,26 +109,8 @@ class MongoBackend(BaseBackend):
 
         db = self._get_database()
         taskmeta_collection = db[self.mongodb_taskmeta_collection]
-
         taskmeta_collection.save(meta, safe=True)
 
-    def get_status(self, task_id):
-        """Get status of a task."""
-        return self._get_task_meta_for(task_id)["status"]
-
-    def get_traceback(self, task_id):
-        """Get the traceback of a failed task."""
-        meta = self._get_task_meta_for(task_id)
-        return meta["traceback"]
-
-    def get_result(self, task_id):
-        """Get the result for a task."""
-        meta = self._get_task_meta_for(task_id)
-        if meta["status"] in states.EXCEPTION_STATES:
-            return self.exception_to_python(meta["result"])
-        else:
-            return meta["result"]
-
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
         if task_id in self._cache:

+ 0 - 1
celery/contrib/batches.py

@@ -43,7 +43,6 @@ class Counter(Task):
         raise NotImplementedError("Counters must implement 'flush'")
 
 
-
 class ClickCounter(Task):
     flush_every = 1000
 

+ 13 - 0
celery/models.py

@@ -1,5 +1,6 @@
 import django
 from django.db import models
+from django.forms.models import model_to_dict
 from django.utils.translation import ugettext_lazy as _
 
 from picklefield.fields import PickledObjectField
@@ -27,6 +28,13 @@ class TaskMeta(models.Model):
         verbose_name = _(u"task meta")
         verbose_name_plural = _(u"task meta")
 
+    def to_dict(self):
+        return {"task_id": self.task_id,
+                "status": self.status,
+                "result": self.result,
+                "date_done": self.date_done,
+                "traceback": self.traceback}
+
     def __unicode__(self):
         return u"<Task: %s successful: %s>" % (self.task_id, self.status)
 
@@ -44,6 +52,11 @@ class TaskSetMeta(models.Model):
         verbose_name = _(u"taskset meta")
         verbose_name_plural = _(u"taskset meta")
 
+    def to_dict(self):
+        return {"taskset_id": self.taskset_id,
+                "result": self.result,
+                "date_done": self.date_done}
+
     def __unicode__(self):
         return u"<TaskSet: %s>" % (self.taskset_id)
 

+ 0 - 2
celery/task/base.py

@@ -413,8 +413,6 @@ class Task(object):
         wrapper.execute_using_pool(pool, loglevel, logfile)
 
 
-
-
 class ExecuteRemoteTask(Task):
     """Execute an arbitrary function or object.
 

+ 5 - 0
celery/tests/runners.py

@@ -9,6 +9,11 @@ def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=None,
     """
     extra_tests = extra_tests or []
     app_labels = getattr(settings, "TEST_APPS", test_labels)
+
+    # Seems to be deleting the test database file twice :(
+    from celery.utils import noop
+    from django.db import connection
+    connection.creation.destroy_test_db = noop
     return django_test_runner(app_labels,
                               verbosity=verbosity, interactive=interactive,
                               extra_tests=extra_tests, **kwargs)

+ 0 - 1
celery/tests/test_task_http.py

@@ -76,7 +76,6 @@ class TestMutableURL(unittest.TestCase):
         self.assertEquals(url.query.get("z"), "Foo")
         self.assertEquals(url.query.get("name"), "George")
 
-
     def test_url_keeps_everything(self):
         url = "https://e.com:808/foo/bar#zeta?x=10&y=20"
         url = http.MutableURL(url)