Browse Source

100% Coverage for celery.backends.pyredis

Ask Solem 14 years ago
parent
commit
8f09b6967c

+ 33 - 51
celery/backends/pyredis.py

@@ -3,6 +3,7 @@ from datetime import timedelta
 from celery.backends.base import KeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.utils import timeutils
+from celery.utils import cached_property
 
 try:
     import redis
@@ -13,31 +14,26 @@ except ImportError:
 
 
 class RedisBackend(KeyValueStoreBackend):
-    """Redis based task backend store.
+    """Redis task result store."""
 
-    .. attribute:: redis_host
-
-        The hostname to the Redis server.
-
-    .. attribute:: redis_port
-
-        The port to the Redis server.
-
-        Raises :class:`celery.exceptions.ImproperlyConfigured` if
-        the :setting:`REDIS_HOST` or :setting:`REDIS_PORT` settings is not set.
-
-    """
+    #: redis-py client module.
     redis = redis
+
+    #: default Redis server hostname (`localhost`).
     redis_host = "localhost"
+
+    #: default Redis server port (6379)
     redis_port = 6379
     redis_db = 0
+
+    #: default Redis password (:const:`None`)
     redis_password = None
 
     def __init__(self, redis_host=None, redis_port=None, redis_db=None,
             redis_password=None,
             expires=None, **kwargs):
         super(RedisBackend, self).__init__(**kwargs)
-        if redis is None:
+        if self.redis is None:
             raise ImproperlyConfigured(
                     "You need to install the redis library in order to use "
                   + "Redis result store backend.")
@@ -61,48 +57,34 @@ class RedisBackend(KeyValueStoreBackend):
             self.expires = timeutils.timedelta_seconds(self.expires)
         if self.expires is not None:
             self.expires = int(self.expires)
+        self.redis_port = int(self.redis_port)
 
-        if self.redis_port:
-            self.redis_port = int(self.redis_port)
-        if not self.redis_host or not self.redis_port:
-            raise ImproperlyConfigured(
-                "In order to use the Redis result store backend, you have to "
-                "set the REDIS_HOST and REDIS_PORT settings")
-        self._connection = None
-
-    def open(self):
-        """Get :class:`redis.Redis` instance with the current
-        server configuration.
-
-        The connection is then cached until you do an
-        explicit :meth:`close`.
-
-        """
-        # connection overrides bool()
-        if self._connection is None:
-            self._connection = self.redis.Redis(host=self.redis_host,
-                                                port=self.redis_port,
-                                                db=self.redis_db,
-                                                password=self.redis_password)
-        return self._connection
+    def get(self, key):
+        return self.client.get(key)
+
+    def set(self, key, value):
+        client = self.client
+        client.set(key, value)
+        if self.expires is not None:
+            client.expire(key, self.expires)
+
+    def delete(self, key):
+        self.client.delete(key)
 
     def close(self):
-        """Close the connection to redis."""
-        if self._connection is not None:
-            self._connection.connection.disconnect()
-            self._connection = None
+        """Closes the Redis connection."""
+        del(self.client)
 
     def process_cleanup(self):
         self.close()
 
-    def get(self, key):
-        return self.open().get(key)
+    @cached_property
+    def client(self):
+        return self.redis.Redis(host=self.redis_host,
+                                port=self.redis_port,
+                                db=self.redis_db,
+                                password=self.redis_password)
 
-    def set(self, key, value):
-        r = self.open()
-        r.set(key, value)
-        if self.expires is not None:
-            r.expire(key, self.expires)
-
-    def delete(self, key):
-        self.open().delete(key)
+    @client.deleter
+    def client(self, client):
+        client.connection.disconnect()

+ 13 - 30
celery/tests/test_backends/test_redis.py

@@ -43,9 +43,8 @@ def get_redis_or_SkipTest():
     try:
         tb = RedisBackend(redis_db="celery_unittest")
         try:
-            tb.open()
             # Evaluate lazy connection
-            tb._connection.connection.connect(tb._connection)
+            tb.client.connection.connect(tb.client)
         except ConnectionError, exc:
             emit_no_redis_msg("not running")
             raise SkipTest("can't connect to redis: %s" % (exc, ))
@@ -58,15 +57,6 @@ def get_redis_or_SkipTest():
 
 class TestRedisBackend(unittest.TestCase):
 
-    def test_cached_connection(self):
-        tb = get_redis_or_SkipTest()
-
-        self.assertIsNotNone(tb._connection)
-        tb.close()
-        self.assertIsNone(tb._connection)
-        tb.open()
-        self.assertIsNotNone(tb._connection)
-
     def test_mark_as_done(self):
         tb = get_redis_or_SkipTest()
 
@@ -102,27 +92,20 @@ class TestRedisBackend(unittest.TestCase):
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
         self.assertIsInstance(tb.get_result(tid3), KeyError)
 
-    def test_process_cleanup(self):
-        tb = get_redis_or_SkipTest()
-
-        tb.process_cleanup()
-
-        self.assertIsNone(tb._connection)
-
     def test_connection_close_if_connected(self):
         tb = get_redis_or_SkipTest()
 
-        tb.open()
-        self.assertIsNotNone(tb._connection)
-        tb.close()
-        self.assertIsNone(tb._connection)
-        tb.close()
-        self.assertIsNone(tb._connection)
+        client = tb.client
+        self.assertIsNotNone(client)
+        tb.process_cleanup()
+        self.assertRaises(KeyError, tb.__dict__.__getitem__, "client")
+        tb.process_cleanup()
+        self.assertRaises(KeyError, tb.__dict__.__getitem__, "client")
 
 
-class TestTyrantBackendNoTyrant(unittest.TestCase):
+class TestRedisBackendNoRedis(unittest.TestCase):
 
-    def test_tyrant_None_if_tyrant_not_installed(self):
+    def test_redis_None_if_redis_not_installed(self):
         prev = sys.modules.pop("celery.backends.pyredis")
         try:
             def with_redis_masked(_val):
@@ -133,11 +116,11 @@ class TestTyrantBackendNoTyrant(unittest.TestCase):
         finally:
             sys.modules["celery.backends.pyredis"] = prev
 
-    def test_constructor_raises_if_tyrant_not_installed(self):
+    def test_constructor_raises_if_redis_not_installed(self):
         from celery.backends import pyredis
-        prev = pyredis.redis
-        pyredis.redis = None
+        prev = pyredis.RedisBackend.redis
+        pyredis.RedisBackend.redis = None
         try:
             self.assertRaises(ImproperlyConfigured, pyredis.RedisBackend)
         finally:
-            pyredis.redis = prev
+            pyredis.RedisBackend.redis = prev

+ 102 - 0
celery/tests/test_backends/test_redis_unit.py

@@ -0,0 +1,102 @@
+from datetime import timedelta
+
+from celery import states
+from celery.app import app_or_default
+from celery.utils import gen_unique_id
+
+from celery.tests.utils import unittest
+
+
+class Redis(object):
+
+    class Connection(object):
+        connected = True
+
+        def disconnect(self):
+            self.connected = False
+
+    def __init__(self, host=None, port=None, db=None, password=None, **kw):
+        self.host = host
+        self.port = port
+        self.db = db
+        self.password = password
+        self.connection = self.Connection()
+        self.keyspace = {}
+        self.expiry = {}
+
+    def get(self, key):
+        return self.keyspace.get(key)
+
+    def set(self, key, value):
+        self.keyspace[key] = value
+
+    def expire(self, key, expires):
+        self.expiry[key] = expires
+
+    def delete(self, key):
+        self.keyspace.pop(key)
+
+
+class redis(object):
+    Redis = Redis
+
+
+class test_RedisBackend(unittest.TestCase):
+
+    def get_backend(self):
+        from celery.backends import pyredis
+
+        class RedisBackend(pyredis.RedisBackend):
+            redis = redis
+
+        return RedisBackend
+
+    def setUp(self):
+        self.Backend = self.get_backend()
+
+    def test_expires_defaults_to_config(self):
+        app = app_or_default()
+        prev = app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES
+        app.conf.CELERY_TASK_RESULT_EXPIRES = 10
+        try:
+            b = self.Backend(expires=None)
+            self.assertEqual(b.expires, 10)
+        finally:
+            app.conf.CELERY_TASK_RESULT_EXPIRES = prev
+
+    def test_expires_is_int(self):
+        b = self.Backend(expires=48)
+        self.assertEqual(b.expires, 48)
+
+    def test_expires_is_None(self):
+        b = self.Backend(expires=None)
+        self.assertIsNone(b.expires)
+
+    def test_expires_is_timedelta(self):
+        b = self.Backend(expires=timedelta(minutes=1))
+        self.assertEqual(b.expires, 60)
+
+    def test_get_set_forget(self):
+        b = self.Backend()
+        uuid = gen_unique_id()
+        b.store_result(uuid, 42, states.SUCCESS)
+        self.assertEqual(b.get_status(uuid), states.SUCCESS)
+        self.assertEqual(b.get_result(uuid), 42)
+        b.forget(uuid)
+        self.assertEqual(b.get_status(uuid), states.PENDING)
+
+    def test_set_expires(self):
+        b = self.Backend(expires=512)
+        uuid = gen_unique_id()
+        key = b.get_key_for_task(uuid)
+        b.store_result(uuid, 42, states.SUCCESS)
+        self.assertEqual(b.client.expiry[key], 512)
+
+    def test_closes_connection_at_process_cleanup(self):
+        b = self.Backend(expires=512)
+        client = b.client
+        self.assertTrue(client.connection.connected)
+        b.process_cleanup()
+        self.assertFalse(client.connection.connected)
+        b.process_cleanup()
+        self.assertFalse(client.connection.connected)

+ 5 - 4
celery/utils/__init__.py

@@ -361,9 +361,10 @@ class cached_property(object):
             return value
 
         @connection.deleter
-        def connection(self):
+        def connection(self, value):
             # Additional action to do at del(self.attr)
-            print("Next access will give a new connection")
+            if value is not None:
+                print("Connection %r deleted" % (value, ))
 
     """
 
@@ -395,12 +396,12 @@ class cached_property(object):
         if obj is None:
             return self
         try:
-            del(obj.__dict__[self.__name__])
+            value = obj.__dict__.pop(self.__name__)
         except KeyError:
             pass
         else:
             if self.__del is not None:
-                self.__del(obj)
+                self.__del(obj, value)
 
     def setter(self, fset):
         return self.__class__(self.__get, fset, self.__del)

+ 0 - 6
setup.cfg

@@ -4,12 +4,9 @@ cover3-branch = 1
 cover3-html = 1
 cover3-package = celery
 cover3-exclude = celery
-                 celery.conf
                  celery.tests.*
-                 celery.bin.celeryev
                  celery.bin.celeryd_multi
                  celery.bin.celeryd_detach
-                 celery.bin.celeryctl
                  celery.bin.camqadm
                  celery.execute
                  celery.platforms
@@ -19,17 +16,14 @@ cover3-exclude = celery
                  celery.utils.functional
                  celery.utils.dispatch*
                  celery.utils.term
-                 celery.utils.timer2
                  celery.db.a805d4bd
                  celery.db.dfd042c7
                  celery.contrib*
                  celery.concurrency.threads
                  celery.concurrency.processes.pool
                  celery.concurrency.evg
-                 celery.concurrency.evlet
                  celery.backends.mongodb
                  celery.backends.tyrant
-                 celery.backends.pyredis
                  celery.backends.cassandra
                  celery.events.dumper
                  celery.events.cursesmon