Explorar o código

My fix for issue #1722: unix sockets not working when redis configured as a backend

Alcides Viamontes Esquivel %!s(int64=11) %!d(string=hai) anos
pai
achega
e342c16c71
Modificáronse 1 ficheiros con 117 adicións e 31 borrados
  1. 117 31
      celery/backends/redis.py

+ 117 - 31
celery/backends/redis.py

@@ -8,6 +8,8 @@
 """
 from __future__ import absolute_import
 
+import re
+
 from kombu.utils import cached_property
 from kombu.utils.url import _parse_url
 
@@ -28,24 +30,106 @@ REDIS_MISSING = """\
 You need to install the redis library in order to use \
 the Redis result store backend."""
 
+class RedisConnectionParams(object):
+    """
+    Bulky class/module that handles taking Redis connection
+    parameters from a url, and mixing them from different default
+    sources. Should probably be used both by kombu and this module...
 
-class RedisBackend(KeyValueStoreBackend):
-    """Redis task result store."""
 
-    #: redis-py client module.
-    redis = redis
+    """
+
+    _tcp_default_params = {
+        #: default Redis server hostname (`localhost`).
+        'host' : 'localhost',
 
-    #: default Redis server hostname (`localhost`).
-    host = 'localhost'
+        #: default Redis server port (6379)
+        'port' : 6379,
 
-    #: default Redis server port (6379)
-    port = 6379
+        #: default Redis db number (0)
+        'db' : 0,
 
-    #: default Redis db number (0)
-    db = 0
+        #: default Redis password (:const:`None`)
+        'password' : None,
 
-    #: default Redis password (:const:`None`)
-    password = None
+        'connection_class': redis.Connection
+    }
+
+    _unix_default_params = {
+        #: default Redis db number (0)
+        'db' : 0,
+
+        #: default Redis password (:const:`None`)
+        'password' : None,
+
+        'connection_class': redis.UnixDomainSocketConnection
+    }
+
+    @staticmethod
+    def prepare_connection_params(given_params, default_params_1=None, default_params_2=None):
+        """ Creates a dictionary with connection parameters, where a key in 'given_params' has greater
+            priority, default_params_1 has lower, and default_params_2 has the lowest. In all the cases,
+            a key not present or with value None will trigger a lookup in the following level.
+        """
+        assert isinstance(given_params, dict)
+        if default_params_1 is None: default_params_1 = {}
+
+        connection_class = given_params.get('connection_class')
+        if connection_class and connection_class is redis.UnixDomainSocketConnection:
+            default_params = RedisConnectionParams._unix_default_params
+        else:
+            default_params = RedisConnectionParams._tcp_default_params
+
+        if default_params_2 is None:
+            default_params_2 = default_params
+
+        result = given_params.copy()
+        for key in default_params.keys():
+            if key not in result:
+                possible_value = default_params_1.get(key) or default_params_2.get(key)
+                if possible_value is not None:
+                    result[key] = possible_value
+        return result
+
+    @staticmethod
+    def connparams_from_url(url):
+        scheme, host, port, user, password, path, query = _parse_url(url)
+
+        connparams = {}
+        if host: connparams['host'] = host
+        if port: connparams['port'] = port
+        if user: connparams['user'] = user
+        if password: connparams['password'] = password
+
+        if query and 'virtual_host' in query:
+            db_no = query['virtual_host']
+            del query['virtual_host']
+            query['db'] = int(db_no)
+
+        if scheme == 'socket':
+            # Use 'path' as path to the socket... in this case
+            # the database number should be given in 'query'
+            connparams.update({
+                'connection_class': redis.UnixDomainSocketConnection,
+                'path': '/' + path})
+            connparams.pop('host', None)
+            connparams.pop('port', None)
+        else:
+            #  Use 'path' to deduce a database number
+            maybe_vhost = re.search(r'/(\d+)/?$', path)
+            if maybe_vhost:
+                db = int(maybe_vhost.group(1))
+                connparams['db'] = db
+        # Query parameters override other parameters
+        connparams.update(query)
+        return connparams
+
+
+class RedisBackend(KeyValueStoreBackend):
+    """Redis task result store."""
+
+    #: redis-py client module.
+    redis = redis
 
     #: Maximium number of connections in the pool.
     max_connections = None
@@ -69,16 +153,22 @@ class RedisBackend(KeyValueStoreBackend):
                 except KeyError:
                     pass
         if host and '://' in host:
-            url, host = host, None
-        self.url = url
-        uhost = uport = upass = udb = None
-        if url:
-            _, uhost, uport, _, upass, udb, _ = _parse_url(url)
-            udb = udb.strip('/') if udb else 0
-        self.host = uhost or host or _get('HOST') or self.host
-        self.port = int(uport or port or _get('PORT') or self.port)
-        self.db = udb or db or _get('DB') or self.db
-        self.password = upass or password or _get('PASSWORD') or self.password
+            url = host
+            host = None
+
+        old_config_port = _get('PORT')
+
+        connparams = RedisConnectionParams.prepare_connection_params(
+            RedisConnectionParams.connparams_from_url(url) if url else {},
+            {
+                'host': _get('HOST'),
+                'port': int(old_config_port) if old_config_port else None,
+                'db': _get('DB'),
+                'password': _get('PASSWORD')
+            }
+        )
+
+        self.connparams = connparams
         self.expires = self.prepare_expires(expires, type=int)
         self.max_connections = (max_connections
                                 or _get('MAX_CONNECTIONS')
@@ -109,17 +199,13 @@ class RedisBackend(KeyValueStoreBackend):
 
     @cached_property
     def client(self):
-        pool = self.redis.ConnectionPool(host=self.host, port=self.port,
-                                         db=self.db, password=self.password,
-                                         max_connections=self.max_connections)
+        pool = self.redis.ConnectionPool(max_connections=self.max_connections,
+                                         **self.connparams)
         return self.redis.Redis(connection_pool=pool)
 
     def __reduce__(self, args=(), kwargs={}):
         kwargs.update(
-            dict(host=self.host,
-                 port=self.port,
-                 db=self.db,
-                 password=self.password,
-                 expires=self.expires,
-                 max_connections=self.max_connections))
+            dict(expires=self.expires,
+                 max_connections=self.max_connections, **self.connparams))
         return super(RedisBackend, self).__reduce__(args, kwargs)
+