Browse Source

Cosmetics for #1724

Ask Solem 11 years ago
parent
commit
6574ff18b1

+ 65 - 116
celery/backends/redis.py

@@ -8,12 +8,12 @@
 """
 from __future__ import absolute_import
 
-import re
-
 from kombu.utils import cached_property
 from kombu.utils.url import _parse_url
 
 from celery.exceptions import ImproperlyConfigured
+from celery.five import string_t
+from celery.utils import deprecated_property
 
 from .base import KeyValueStoreBackend
 
@@ -30,99 +30,12 @@ REDIS_MISSING = """\
 You need to install the redis library in order to use \
 the Redis result store backend."""
 
-class RedisConnectionParams(object):
-    """
-    Bulky class/module that handles taking Redis connection
-    parameters from a url, and mixing them from different default
-    sources. Should probably be used both by kombu and this module...
-
-
-    """
-
-    _tcp_default_params = {
-        #: default Redis server hostname (`localhost`).
-        'host' : 'localhost',
-
-        #: default Redis server port (6379)
-        'port' : 6379,
-
-        #: default Redis db number (0)
-        'db' : 0,
-
-        #: default Redis password (:const:`None`)
-        'password' : None,
-
-        'connection_class': redis.Connection
-    }
-
-    _unix_default_params = {
-        #: default Redis db number (0)
-        'db' : 0,
-
-        #: default Redis password (:const:`None`)
-        'password' : None,
-
-        'connection_class': redis.UnixDomainSocketConnection
-    }
-
-    @staticmethod
-    def prepare_connection_params(given_params, default_params_1=None, default_params_2=None):
-        """ Creates a dictionary with connection parameters, where a key in 'given_params' has greater
-            priority, default_params_1 has lower, and default_params_2 has the lowest. In all the cases,
-            a key not present or with value None will trigger a lookup in the following level.
-        """
-        assert isinstance(given_params, dict)
-        if default_params_1 is None: default_params_1 = {}
-
-        connection_class = given_params.get('connection_class')
-        if connection_class and connection_class is redis.UnixDomainSocketConnection:
-            default_params = RedisConnectionParams._unix_default_params
-        else:
-            default_params = RedisConnectionParams._tcp_default_params
-
-        if default_params_2 is None:
-            default_params_2 = default_params
-
-        result = given_params.copy()
-        for key in default_params.keys():
-            if key not in result:
-                possible_value = default_params_1.get(key) or default_params_2.get(key)
-                if possible_value is not None:
-                    result[key] = possible_value
-        return result
-
-    @staticmethod
-    def connparams_from_url(url):
-        scheme, host, port, user, password, path, query = _parse_url(url)
-
-        connparams = {}
-        if host: connparams['host'] = host
-        if port: connparams['port'] = port
-        if user: connparams['user'] = user
-        if password: connparams['password'] = password
-
-        if query and 'virtual_host' in query:
-            db_no = query['virtual_host']
-            del query['virtual_host']
-            query['db'] = int(db_no)
-
-        if scheme == 'socket':
-            # Use 'path' as path to the socket... in this case
-            # the database number should be given in 'query'
-            connparams.update({
-                'connection_class': redis.UnixDomainSocketConnection,
-                'path': '/' + path})
-            connparams.pop('host', None)
-            connparams.pop('port', None)
-        else:
-            #  Use 'path' to deduce a database number
-            maybe_vhost = re.search(r'/(\d+)/?$', path)
-            if maybe_vhost:
-                db = int(maybe_vhost.group(1))
-                connparams['db'] = db
-        # Query parameters override other parameters
-        connparams.update(query)
-        return connparams
+default_params = {
+    'host': 'localhost',
+    'port': 6379,
+    'db': 0,
+    'password': None,
+}
 
 
 class RedisBackend(KeyValueStoreBackend):
@@ -156,23 +69,46 @@ class RedisBackend(KeyValueStoreBackend):
             url = host
             host = None
 
-        old_config_port = _get('PORT')
-
-        connparams = RedisConnectionParams.prepare_connection_params(
-            RedisConnectionParams.connparams_from_url(url) if url else {},
-            {
-                'host': _get('HOST'),
-                'port': int(old_config_port) if old_config_port else None,
-                'db': _get('DB'),
-                'password': _get('PASSWORD')
-            }
+        self.max_connections = (
+            max_connections or _get('MAX_CONNECTIONS') or self.max_connections
         )
 
-        self.connparams = connparams
+        self.connparams = {
+            'host': _get('HOST') or 'localhost',
+            'port': _get('PORT') or 6379,
+            'db': _get('DB') or 0,
+            'password': _get('PASSWORD'),
+            'max_connections': max_connections,
+        }
+        if url:
+            self.connparams = self._params_from_url(url, self.connparams)
+        self.url = url
         self.expires = self.prepare_expires(expires, type=int)
-        self.max_connections = (max_connections
-                                or _get('MAX_CONNECTIONS')
-                                or self.max_connections)
+
+    def _params_from_url(self, url, defaults):
+        scheme, host, port, user, password, path, query = _parse_url(url)
+        connparams = dict(
+            defaults,
+            host=host, port=port, user=user, password=password,
+            db=int(query.pop('virtual_host', None) or 0),
+        )
+
+        if scheme == 'socket':
+            # Use 'path' as path to the socket... in this case
+            # the database number should be given in 'query'
+            connparams.update({
+                'connection_class': self.redis.UnixDomainSocketConnection,
+                'path': '/' + path,
+            })
+            connparams.pop('host', None)
+            connparams.pop('port', None)
+        else:
+            path = path.strip('/') if isinstance(path, string_t) else path
+            if path:
+                connparams['db'] = int(path)
+        # Query parameters override other parameters
+        connparams.update(query)
+        return connparams
 
     def get(self, key):
         return self.client.get(key)
@@ -199,13 +135,26 @@ class RedisBackend(KeyValueStoreBackend):
 
     @cached_property
     def client(self):
-        pool = self.redis.ConnectionPool(max_connections=self.max_connections,
-                                         **self.connparams)
-        return self.redis.Redis(connection_pool=pool)
+        return self.redis.Redis(
+            connection_pool=self.redis.ConnectionPool(**self.connparams))
 
     def __reduce__(self, args=(), kwargs={}):
-        kwargs.update(
-            dict(expires=self.expires,
-                 max_connections=self.max_connections, **self.connparams))
-        return super(RedisBackend, self).__reduce__(args, kwargs)
+        return super(RedisBackend, self).__reduce__(
+            (self.url, ), {'expires': self.expires},
+        )
+
+    @deprecated_property(3.2, 3.3)
+    def host(self):
+        return self.connparams['host']
+
+    @deprecated_property(3.2, 3.3)
+    def port(self):
+        return self.connparams['port']
+
+    @deprecated_property(3.2, 3.3)
+    def db(self):
+        return self.connparams['db']
 
+    @deprecated_property(3.2, 3.3)
+    def password(self):
+        return self.connparams['password']

+ 16 - 16
celery/events/state.py

@@ -174,25 +174,25 @@ class Worker(object):
     def id(self):
         return '{0.hostname}.{0.pid}'.format(self)
 
-    @deprecated('3.2' '3.3')
+    @deprecated(3.2, 3.3)
     def update_heartbeat(self, received, timestamp):
         self.event(None, timestamp, received)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_online(self, timestamp=None, local_received=None, **fields):
         self.event('online', timestamp, local_received, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_offline(self, timestamp=None, local_received=None, **fields):
         self.event('offline', timestamp, local_received, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_heartbeat(self, timestamp=None, local_received=None, **fields):
         self.event('heartbeat', timestamp, local_received, fields)
 
     @class_property
     def _defaults(cls):
-        """Deprecated, to be removed in 3.2"""
+        """Deprecated, to be removed in 3.3"""
         source = cls()
         return dict((k, getattr(source, k)) for k in cls._fields)
 
@@ -310,44 +310,44 @@ class Task(object):
     def ready(self):
         return self.state in states.READY_STATES
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_sent(self, timestamp=None, **fields):
         self.event('sent', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_received(self, timestamp=None, **fields):
         self.event('received', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_started(self, timestamp=None, **fields):
         self.event('started', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_failed(self, timestamp=None, **fields):
         self.event('failed', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_retried(self, timestamp=None, **fields):
         self.event('retried', timestamp, fields)
 
-    @deprecated('3.2' '3.3')
+    @deprecated(3.2, 3.3)
     def on_succeeded(self, timestamp=None, **fields):
         self.event('succeeded', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_revoked(self, timestamp=None, **fields):
         self.event('revoked', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_unknown_event(self, shortype, timestamp=None, **fields):
         self.event(shortype, timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def update(self, state, timestamp, fields,
                _state=states.state, RETRY=states.RETRY):
         return self.event(state, timestamp, None, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def merge(self, state, timestamp, fields):
         keep = self.merge_rules.get(state)
         if keep is not None:
@@ -357,7 +357,7 @@ class Task(object):
 
     @class_property
     def _defaults(cls):
-        """Deprecated, to be removed in 3.2."""
+        """Deprecated, to be removed in 3.3."""
         source = cls()
         return dict((k, getattr(source, k)) for k in source._fields)
 

+ 2 - 2
celery/loaders/__init__.py

@@ -25,13 +25,13 @@ def get_loader_cls(loader):
     return symbol_by_name(loader, LOADER_ALIASES, imp=import_from_cwd)
 
 
-@deprecated(deprecation='2.5', removal='4.0',
+@deprecated(deprecation=2.5, removal=4.0,
             alternative='celery.current_app.loader')
 def current_loader():
     return current_app.loader
 
 
-@deprecated(deprecation='2.5', removal='4.0',
+@deprecated(deprecation=2.5, removal=4.0,
             alternative='celery.current_app.conf')
 def load_settings():
     return current_app.conf

+ 1 - 1
celery/tests/backends/test_redis.py

@@ -101,7 +101,7 @@ class test_RedisBackend(AppCase):
     def test_url(self):
         x = self.MockBackend('redis://foobar//1', app=self.app)
         self.assertEqual(x.host, 'foobar')
-        self.assertEqual(x.db, '1')
+        self.assertEqual(x.db, 1)
 
     def test_conf_raises_KeyError(self):
         self.app.conf = AttributeDict({

+ 49 - 0
celery/utils/__init__.py

@@ -113,6 +113,55 @@ def deprecated(deprecation=None, removal=None,
     return _inner
 
 
+def deprecated_property(deprecation=None, removal=None,
+                        alternative=None, description=None):
+    def _inner(fun):
+        return _deprecated_property(
+            fun, deprecation=deprecation, removal=removal,
+            alternative=alternative, description=description or fun.__name__)
+    return _inner
+
+
+class _deprecated_property(object):
+
+    def __init__(self, fget=None, fset=None, fdel=None, doc=None, **depreinfo):
+        self.__get = fget
+        self.__set = fset
+        self.__del = fdel
+        self.__name__, self.__module__, self.__doc__ = (
+            fget.__name__, fget.__module__, fget.__doc__,
+        )
+        self.depreinfo = depreinfo
+
+    def __get__(self, obj, type=None):
+        if obj is None:
+            return self
+        warn_deprecated(**self.depreinfo)
+        return self.__get(obj)
+
+    def __set__(self, obj, value):
+        if obj is None:
+            return self
+        if self.__set is None:
+            raise AttributeError('cannot set attribute')
+        warn_deprecated(**self.depreinfo)
+        self.__set(obj, value)
+
+    def __delete__(self, obj):
+        if obj is None:
+            return self
+        if self.__del is None:
+            raise AttributeError('cannot delete attribute')
+        warn_deprecated(**self.depreinfo)
+        self.__del(obj)
+
+    def setter(self, fset):
+        return self.__class__(self.__get, fset, self.__del, **self.depreinfo)
+
+    def deleter(self, fdel):
+        return self.__class__(self.__get, self.__set, fdel, **self.depreinfo)
+
+
 def lpmerge(L, R):
     """In place left precedent dictionary merge.