Bläddra i källkod

got tests all passing for database taskmeta backend

Brad Jasper 15 år sedan
förälder
incheckning
f437c20452

+ 13 - 0
celery/backends/base.py

@@ -106,6 +106,19 @@ class BaseBackend(object):
         """
         pass
 
+    def store_taskset(self, taskset_id, result):
+        """Store the result and status of a task."""
+        raise NotImplementedError(
+                "store_taskset is not supported by this backend.")
+
+    def get_taskset(self, task_id):
+        """Get the result of a taskset."""
+        raise NotImplementedError(
+                "get_taskset is not supported by this backend.")
+
+
+
+
 
 class KeyValueStoreBackend(BaseBackend):
 

+ 20 - 0
celery/backends/database.py

@@ -40,6 +40,11 @@ class Backend(BaseBackend):
                                       traceback=traceback)
         return result
 
+    def store_taskset(self, taskset_id, result):
+        """Store the result of an executed taskset."""
+        TaskSetMeta.objects.store_result(taskset_id, result)
+        return result
+
     def is_done(self, task_id):
         """Returns ``True`` if task with ``task_id`` has been executed."""
         return self.get_status(task_id) == "DONE"
@@ -69,6 +74,21 @@ class Backend(BaseBackend):
             self._cache[task_id] = meta
         return meta
 
+    def get_taskset(self, taskset_id):
+        """Get the result for a taskset."""
+        meta = self._get_taskset_meta_for(taskset_id)
+        if meta:
+            return meta.result
+
+    def _get_taskset_meta_for(self, taskset_id):
+        """Get taskset metadata for a taskset by id."""
+        if taskset_id in self._cache:
+            return self._cache[taskset_id]
+        meta = TaskSetMeta.objects.get_taskset(taskset_id)
+        if meta:
+            self._cache[taskset_id] = meta
+            return meta
+
     def cleanup(self):
         """Delete expired metadata."""
         TaskMeta.objects.delete_expired()

+ 43 - 46
celery/managers.py

@@ -59,52 +59,6 @@ class MySQLTableLock(TableLock):
 TABLE_LOCK_FOR_ENGINE = {"mysql": MySQLTableLock}
 table_lock = TABLE_LOCK_FOR_ENGINE.get(settings.DATABASE_ENGINE, TableLock)
 
-
-class TaskSetManager(models.Manager):
-    """Manager for :class:`celery.models.TaskSet` models."""
-
-    def get_all_expired(self):
-        """Get all expired taskset results."""
-        return self.filter(date_done__lt=datetime.now() - TASK_RESULT_EXPIRES)
-
-    def delete_expired(self):
-        """Delete all expired taskset results."""
-        self.get_all_expired().delete()
-
-    def get_result(self, taskset_id):
-        """Get task meta for task by ``taskset_id``."""
-        try:
-            return self.get(taskset_id=taskset_id)
-        except self.model.DoesNotExist:
-            return None
-
-    def store_result(self, taskset_id, result, exception_retry=True):
-        """Store the result of a taskset.
-
-        :param taskset_id: task set id
-
-        :param result: The return value of the taskset
-
-        """
-        try:
-            taskset, created = self.get_or_create(taskset_id=taskset_id, defaults={
-                                                "result": result})
-            if not created:
-                taskset.result = result
-                taskset.save()
-        except Exception, exc:
-            # depending on the database backend we can get various exceptions.
-            # for excample, psycopg2 raises an exception if some operation
-            # breaks transaction, and saving task result won't be possible
-            # until we rollback transaction
-            if exception_retry:
-                transaction.rollback_unless_managed()
-                self.store_result(taskset_id, result, False)
-            else:
-                raise
-
-
-
 class TaskManager(models.Manager):
     """Manager for :class:`celery.models.Task` models."""
 
@@ -165,6 +119,49 @@ class TaskManager(models.Manager):
             else:
                 raise
 
+class TaskSetManager(models.Manager):
+    """Manager for :class:`celery.models.TaskSet` models."""
+
+    def get_all_expired(self):
+        """Get all expired taskset results."""
+        return self.filter(date_done__lt=datetime.now() - TASK_RESULT_EXPIRES)
+
+    def delete_expired(self):
+        """Delete all expired taskset results."""
+        self.get_all_expired().delete()
+
+    def get_taskset(self, taskset_id):
+        """Get taskset meta for task by ``taskset_id``."""
+        try:
+            return self.get(taskset_id=taskset_id)
+        except self.model.DoesNotExist:
+            return None
+
+    def store_result(self, taskset_id, result, exception_retry=True):
+        """Store the result of a taskset.
+
+        :param taskset_id: task set id
+
+        :param result: The return value of the taskset
+
+        """
+        try:
+            taskset, created = self.get_or_create(taskset_id=taskset_id, defaults={
+                                                "result": result})
+            if not created:
+                taskset.result = result
+                taskset.save()
+        except Exception, exc:
+            # depending on the database backend we can get various exceptions.
+            # for excample, psycopg2 raises an exception if some operation
+            # breaks transaction, and saving task result won't be possible
+            # until we rollback transaction
+            if exception_retry:
+                transaction.rollback_unless_managed()
+                self.store_result(taskset_id, result, False)
+            else:
+                raise
+
 
 class PeriodicTaskManager(models.Manager):
     """Manager for :class:`celery.models.PeriodicTask` models."""

+ 4 - 1
celery/task/base.py

@@ -8,6 +8,7 @@ from celery.utils import gen_unique_id, get_full_cls_name
 from celery.registry import tasks
 from celery.serialization import pickle
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
+from celery.backends import default_backend
 from datetime import timedelta
 
 
@@ -540,7 +541,9 @@ class TaskSet(object):
                         for args, kwargs in self.arguments]
         publisher.close()
         conn.close()
-        return TaskSetResult(taskset_id, subtasks)
+        result = TaskSetResult(taskset_id, subtasks)
+        default_backend.store_taskset(taskset_id, result)
+        return result
 
     def join(self, timeout=None):
         """Gather the results for all of the tasks in the taskset,

+ 9 - 0
celery/tests/test_backends/test_base.py

@@ -34,6 +34,15 @@ class TestBaseBackendInterface(unittest.TestCase):
         self.assertRaises(NotImplementedError,
                 b.get_result, "SOMExx-N0nex1stant-IDxx-")
 
+    def test_get_taskset(self):
+        self.assertRaises(NotImplementedError,
+                b.get_taskset, "SOMExx-N0nex1stant-IDxx-")
+
+    def test_store_taskset(self):
+        self.assertRaises(NotImplementedError,
+                b.store_taskset, "SOMExx-N0nex1stant-IDxx-", "blergh")
+
+
 
 class TestPickleException(unittest.TestCase):
 

+ 14 - 0
celery/tests/test_backends/test_database.py

@@ -67,3 +67,17 @@ class TestDatabaseBackend(unittest.TestCase):
         self.assertFalse(b.is_done(tid3))
         self.assertEquals(b.get_status(tid3), "FAILURE")
         self.assertTrue(isinstance(b.get_result(tid3), KeyError))
+
+    def test_taskset_store(self):
+        b = Backend()
+        tid = gen_unique_id()
+
+        self.assertTrue(b.get_taskset(tid) is None)
+
+        result = {"foo": "baz", "bar": SomeClass(12345)}
+        b.store_taskset(tid, result)
+        rindb = b.get_taskset(tid)
+        self.assertTrue(rindb is not None)
+        self.assertEquals(rindb.get("foo"), "baz")
+        self.assertEquals(rindb.get("bar").data, 12345)
+        self.assertTrue(b._cache.get(tid))

+ 1 - 1
celery/tests/test_models.py

@@ -64,7 +64,7 @@ class TestModels(unittest.TestCase):
         self.assertTrue(m1.taskset_id)
         self.assertTrue(isinstance(m1.date_done, datetime))
 
-        self.assertEquals(TaskSetMeta.objects.get_result(m1.taskset_id).taskset_id,
+        self.assertEquals(TaskSetMeta.objects.get_taskset(m1.taskset_id).taskset_id,
                 m1.taskset_id)
 
         # Have to avoid save() because it applies the auto_now=True.