Browse Source

Redis unit test improvements

Ask Solem 11 years ago
parent
commit
84ab4555a3
2 changed files with 55 additions and 9 deletions
  1. 10 5
      celery/backends/redis.py
  2. 45 4
      celery/tests/backends/test_redis.py

+ 10 - 5
celery/backends/redis.py

@@ -84,22 +84,27 @@ class RedisBackend(KeyValueStoreBackend):
         connparams = dict(
             defaults, **dictfilter({
                 'host': host, 'port': port, 'password': password,
-                'db': int(query.pop('virtual_host', None) or 0)})
+                'db': query.pop('virtual_host', None)})
         )
 
         if scheme == 'socket':
-            # Use 'path' as path to the socket... in this case
+            # use 'path' as path to the socket… in this case
             # the database number should be given in 'query'
             connparams.update({
                 'connection_class': self.redis.UnixDomainSocketConnection,
                 'path': '/' + path,
             })
+            # host+port are invalid options when using this connection type.
             connparams.pop('host', None)
             connparams.pop('port', None)
         else:
-            path = path.strip('/') if isinstance(path, string_t) else path
-            if path:
-                connparams['db'] = int(path)
+            connparams['db'] = path
+
+        # db may be string and start with / like in kombu.
+        db = connparams.get('db') or 0
+        db = db.strip('/') if isinstance(db, string_t) else db
+        connparams['db'] = int(db)
+
         # Query parameters override other parameters
         connparams.update(query)
         return connparams

+ 45 - 4
celery/tests/backends/test_redis.py

@@ -10,7 +10,7 @@ from celery import signature
 from celery import states
 from celery import group
 from celery.datastructures import AttributeDict
-from celery.exceptions import ImproperlyConfigured
+from celery.exceptions import CPendingDeprecationWarning, ImproperlyConfigured
 from celery.utils.timeutils import timedelta_seconds
 
 from celery.tests.case import (
@@ -63,6 +63,11 @@ class redis(object):
         def __init__(self, **kwargs):
             pass
 
+    class UnixDomainSocketConnection(object):
+
+        def __init__(self, **kwargs):
+            pass
+
 
 class test_RedisBackend(AppCase):
 
@@ -100,9 +105,45 @@ class test_RedisBackend(AppCase):
             self.MockBackend(app=self.app)
 
     def test_url(self):
-        x = self.MockBackend('redis://foobar//1', app=self.app)
-        self.assertEqual(x.host, 'foobar')
-        self.assertEqual(x.db, 1)
+        x = self.MockBackend(
+            'redis://:bosco@vandelay.com:123//1', app=self.app,
+        )
+        self.assertTrue(x.connparams)
+        self.assertEqual(x.connparams['host'], 'vandelay.com')
+        self.assertEqual(x.connparams['db'], 1)
+        self.assertEqual(x.connparams['port'], 123)
+        self.assertEqual(x.connparams['password'], 'bosco')
+
+    def test_socket_url(self):
+        x = self.MockBackend(
+            'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
+        )
+        self.assertTrue(x.connparams)
+        self.assertEqual(x.connparams['path'], '/tmp/redis.sock')
+        self.assertIs(
+            x.connparams['connection_class'],
+            redis.UnixDomainSocketConnection,
+        )
+        self.assertNotIn('host', x.connparams)
+        self.assertNotIn('port', x.connparams)
+        self.assertEqual(x.connparams['db'], 3)
+
+    def test_compat_propertie(self):
+        x = self.MockBackend(
+            'redis://:bosco@vandelay.com:123//1', app=self.app,
+        )
+        with self.assertWarnsRegex(CPendingDeprecationWarning,
+                                   r'scheduled for deprecation'):
+            self.assertEqual(x.host, 'vandelay.com')
+        with self.assertWarnsRegex(CPendingDeprecationWarning,
+                                   r'scheduled for deprecation'):
+            self.assertEqual(x.db, 1)
+        with self.assertWarnsRegex(CPendingDeprecationWarning,
+                                   r'scheduled for deprecation'):
+            self.assertEqual(x.port, 123)
+        with self.assertWarnsRegex(CPendingDeprecationWarning,
+                                   r'scheduled for deprecation'):
+            self.assertEqual(x.password, 'bosco')
 
     def test_conf_raises_KeyError(self):
         self.app.conf = AttributeDict({