Browse Source

Adding _forget to the MongoDB backend.

Andrew McFague 13 years ago
parent
commit
eaf78d8a81
2 changed files with 47 additions and 41 deletions
  1. 16 0
      celery/backends/mongodb.py
  2. 31 41
      celery/tests/test_backends/test_mongodb.py

+ 16 - 0
celery/backends/mongodb.py

@@ -178,6 +178,22 @@ class MongoBackend(BaseDictBackend):
         taskmeta_collection = db[self.mongodb_taskmeta_collection]
         taskmeta_collection.remove({"_id": taskset_id})
 
+    def _forget(self, task_id):
+        """
+        Remove result from MongoDB.
+        
+        :raises celery.exceptions.OperationsError: if the task_id could not be
+                                                   removed.
+        """
+
+        db = self._get_database()
+        taskmeta_collection = db[self.mongodb_taskmeta_collection]
+        
+        # By using safe=True, this will wait until it receives a response from
+        # the server.  Likewise, it will raise an OperationsError if the
+        # response was unable to be completed.
+        taskmeta_collection.remove({"_id": task_id}, safe=True)
+
     def cleanup(self):
         """Delete expired metadata."""
         db = self._get_database()

+ 31 - 41
celery/tests/test_backends/test_mongodb.py

@@ -1,64 +1,54 @@
-from __future__ import absolute_import
-
-import sys
+import uuid
 
+from mock import MagicMock, Mock, patch, sentinel
 from nose import SkipTest
 
 from celery.backends.mongodb import MongoBackend
-from celery.exceptions import ImproperlyConfigured
 from celery.tests.utils import unittest
-from celery.utils import uuid
-
-
-_no_mongo_msg = "* MongoDB %s. Will not execute related tests."
-_no_mongo_msg_emitted = False
 
 
 try:
-    from pymongo.errors import AutoReconnect
+    import pymongo
 except ImportError:
+    pymongo = None
 
-    class AutoReconnect(Exception):  # noqa
-        pass
 
+COLLECTION = "taskmeta_celery"
+TASK_ID = str(uuid.uuid1())
 
-def get_mongo_or_SkipTest():
 
-    def emit_no_mongo_msg(reason):
-        global _no_mongo_msg_emitted
-        if not _no_mongo_msg_emitted:
-            sys.stderr.write("\n" + _no_mongo_msg % reason + "\n")
-            _no_mongo_msg_emitted = True
+class TestBackendMongoDb(unittest.TestCase):
 
-    try:
-        tb = MongoBackend()
-        try:
-            tb._get_database()
-        except AutoReconnect, exc:
-            emit_no_mongo_msg("not running")
-            raise SkipTest("Can't connect to MongoDB: %s" % (exc, ))
-        return tb
-    except ImproperlyConfigured, exc:
-        if "need to install" in str(exc):
-            emit_no_mongo_msg("pymongo not installed")
-            raise SkipTest("pymongo not installed")
-        emit_no_mongo_msg("not configured")
-        raise SkipTest("MongoDB not configured correctly: %s" % (exc, ))
+    def setUp(self):
+        if pymongo is None:
+            raise SkipTest("pymongo is not installed.")
 
+        self.backend = MongoBackend()
+        self.backend.mongodb_taskmeta_collection = sentinel.collection
 
-class TestMongoBackend(unittest.TestCase):
+    @patch("celery.backends.mongodb.MongoBackend._get_database")
+    def test_forget(self, mock_get_database):
+        mock_database = MagicMock(spec=['__getitem__', '__setitem__'])
+        mock_collection = Mock()
 
-    def test_save__restore__delete_taskset(self):
-        tb = get_mongo_or_SkipTest()
+        mock_get_database.return_value = mock_database
+        mock_database.__getitem__.return_value = mock_collection
 
-        tid = uuid()
+        self.backend._forget(sentinel.task_id)
+
+        mock_get_database.assert_called_once_with()
+        mock_database.__getitem__.assert_called_once_with(sentinel.collection)
+        mock_collection.remove.assert_called_once_with(
+            {"_id": sentinel.task_id}, safe=True)
+
+    def test_save__restore__delete_taskset(self):
         res = {u"foo": "bar"}
-        self.assertEqual(tb.save_taskset(tid, res), res)
+        self.assertEqual(self.backend.save_taskset(TASK_ID, res), res)
 
-        res2 = tb.restore_taskset(tid)
+        res2 = self.backend.restore_taskset(TASK_ID)
         self.assertEqual(res2, res)
 
-        tb.delete_taskset(tid)
-        self.assertIsNone(tb.restore_taskset(tid))
+        self.backend.delete_taskset(TASK_ID)
+        self.assertIsNone(self.backend.restore_taskset(TASK_ID))
 
-        self.assertIsNone(tb.restore_taskset("xxx-nonexisting-id"))
+        self.assertIsNone(self.backend.restore_taskset("xxx-nonexisting-id"))