Browse Source

Merge branch 'dsign/master'

Ask Solem 11 years ago
parent
commit
8c675ec6b3

+ 71 - 36
celery/backends/redis.py

@@ -12,6 +12,8 @@ from kombu.utils import cached_property
 from kombu.utils.url import _parse_url
 
 from celery.exceptions import ImproperlyConfigured
+from celery.five import string_t
+from celery.utils import deprecated_property
 
 from .base import KeyValueStoreBackend
 
@@ -28,6 +30,13 @@ REDIS_MISSING = """\
 You need to install the redis library in order to use \
 the Redis result store backend."""
 
+default_params = {
+    'host': 'localhost',
+    'port': 6379,
+    'db': 0,
+    'password': None,
+}
+
 
 class RedisBackend(KeyValueStoreBackend):
     """Redis task result store."""
@@ -35,18 +44,6 @@ class RedisBackend(KeyValueStoreBackend):
     #: redis-py client module.
     redis = redis
 
-    #: default Redis server hostname (`localhost`).
-    host = 'localhost'
-
-    #: default Redis server port (6379)
-    port = 6379
-
-    #: default Redis db number (0)
-    db = 0
-
-    #: default Redis password (:const:`None`)
-    password = None
-
     #: Maximium number of connections in the pool.
     max_connections = None
 
@@ -69,20 +66,49 @@ class RedisBackend(KeyValueStoreBackend):
                 except KeyError:
                     pass
         if host and '://' in host:
-            url, host = host, None
-        self.url = url
-        uhost = uport = upass = udb = None
+            url = host
+            host = None
+
+        self.max_connections = (
+            max_connections or _get('MAX_CONNECTIONS') or self.max_connections
+        )
+
+        self.connparams = {
+            'host': _get('HOST') or 'localhost',
+            'port': _get('PORT') or 6379,
+            'db': _get('DB') or 0,
+            'password': _get('PASSWORD'),
+            'max_connections': max_connections,
+        }
         if url:
-            _, uhost, uport, _, upass, udb, _ = _parse_url(url)
-            udb = udb.strip('/') if udb else 0
-        self.host = uhost or host or _get('HOST') or self.host
-        self.port = int(uport or port or _get('PORT') or self.port)
-        self.db = udb or db or _get('DB') or self.db
-        self.password = upass or password or _get('PASSWORD') or self.password
+            self.connparams = self._params_from_url(url, self.connparams)
+        self.url = url
         self.expires = self.prepare_expires(expires, type=int)
-        self.max_connections = (max_connections
-                                or _get('MAX_CONNECTIONS')
-                                or self.max_connections)
+
+    def _params_from_url(self, url, defaults):
+        scheme, host, port, user, password, path, query = _parse_url(url)
+        connparams = dict(
+            defaults,
+            host=host, port=port, user=user, password=password,
+            db=int(query.pop('virtual_host', None) or 0),
+        )
+
+        if scheme == 'socket':
+            # Use 'path' as path to the socket... in this case
+            # the database number should be given in 'query'
+            connparams.update({
+                'connection_class': self.redis.UnixDomainSocketConnection,
+                'path': '/' + path,
+            })
+            connparams.pop('host', None)
+            connparams.pop('port', None)
+        else:
+            path = path.strip('/') if isinstance(path, string_t) else path
+            if path:
+                connparams['db'] = int(path)
+        # Query parameters override other parameters
+        connparams.update(query)
+        return connparams
 
     def get(self, key):
         return self.client.get(key)
@@ -109,17 +135,26 @@ class RedisBackend(KeyValueStoreBackend):
 
     @cached_property
     def client(self):
-        pool = self.redis.ConnectionPool(host=self.host, port=self.port,
-                                         db=self.db, password=self.password,
-                                         max_connections=self.max_connections)
-        return self.redis.Redis(connection_pool=pool)
+        return self.redis.Redis(
+            connection_pool=self.redis.ConnectionPool(**self.connparams))
 
     def __reduce__(self, args=(), kwargs={}):
-        kwargs.update(
-            dict(host=self.host,
-                 port=self.port,
-                 db=self.db,
-                 password=self.password,
-                 expires=self.expires,
-                 max_connections=self.max_connections))
-        return super(RedisBackend, self).__reduce__(args, kwargs)
+        return super(RedisBackend, self).__reduce__(
+            (self.url, ), {'expires': self.expires},
+        )
+
+    @deprecated_property(3.2, 3.3)
+    def host(self):
+        return self.connparams['host']
+
+    @deprecated_property(3.2, 3.3)
+    def port(self):
+        return self.connparams['port']
+
+    @deprecated_property(3.2, 3.3)
+    def db(self):
+        return self.connparams['db']
+
+    @deprecated_property(3.2, 3.3)
+    def password(self):
+        return self.connparams['password']

+ 16 - 16
celery/events/state.py

@@ -174,25 +174,25 @@ class Worker(object):
     def id(self):
         return '{0.hostname}.{0.pid}'.format(self)
 
-    @deprecated('3.2' '3.3')
+    @deprecated(3.2, 3.3)
     def update_heartbeat(self, received, timestamp):
         self.event(None, timestamp, received)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_online(self, timestamp=None, local_received=None, **fields):
         self.event('online', timestamp, local_received, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_offline(self, timestamp=None, local_received=None, **fields):
         self.event('offline', timestamp, local_received, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_heartbeat(self, timestamp=None, local_received=None, **fields):
         self.event('heartbeat', timestamp, local_received, fields)
 
     @class_property
     def _defaults(cls):
-        """Deprecated, to be removed in 3.2"""
+        """Deprecated, to be removed in 3.3"""
         source = cls()
         return dict((k, getattr(source, k)) for k in cls._fields)
 
@@ -310,44 +310,44 @@ class Task(object):
     def ready(self):
         return self.state in states.READY_STATES
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_sent(self, timestamp=None, **fields):
         self.event('sent', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_received(self, timestamp=None, **fields):
         self.event('received', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_started(self, timestamp=None, **fields):
         self.event('started', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_failed(self, timestamp=None, **fields):
         self.event('failed', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_retried(self, timestamp=None, **fields):
         self.event('retried', timestamp, fields)
 
-    @deprecated('3.2' '3.3')
+    @deprecated(3.2, 3.3)
     def on_succeeded(self, timestamp=None, **fields):
         self.event('succeeded', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_revoked(self, timestamp=None, **fields):
         self.event('revoked', timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def on_unknown_event(self, shortype, timestamp=None, **fields):
         self.event(shortype, timestamp, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def update(self, state, timestamp, fields,
                _state=states.state, RETRY=states.RETRY):
         return self.event(state, timestamp, None, fields)
 
-    @deprecated('3.2', '3.3')
+    @deprecated(3.2, 3.3)
     def merge(self, state, timestamp, fields):
         keep = self.merge_rules.get(state)
         if keep is not None:
@@ -357,7 +357,7 @@ class Task(object):
 
     @class_property
     def _defaults(cls):
-        """Deprecated, to be removed in 3.2."""
+        """Deprecated, to be removed in 3.3."""
         source = cls()
         return dict((k, getattr(source, k)) for k in source._fields)
 

+ 2 - 2
celery/loaders/__init__.py

@@ -25,13 +25,13 @@ def get_loader_cls(loader):
     return symbol_by_name(loader, LOADER_ALIASES, imp=import_from_cwd)
 
 
-@deprecated(deprecation='2.5', removal='4.0',
+@deprecated(deprecation=2.5, removal=4.0,
             alternative='celery.current_app.loader')
 def current_loader():
     return current_app.loader
 
 
-@deprecated(deprecation='2.5', removal='4.0',
+@deprecated(deprecation=2.5, removal=4.0,
             alternative='celery.current_app.conf')
 def load_settings():
     return current_app.conf

+ 1 - 1
celery/tests/backends/test_redis.py

@@ -101,7 +101,7 @@ class test_RedisBackend(AppCase):
     def test_url(self):
         x = self.MockBackend('redis://foobar//1', app=self.app)
         self.assertEqual(x.host, 'foobar')
-        self.assertEqual(x.db, '1')
+        self.assertEqual(x.db, 1)
 
     def test_conf_raises_KeyError(self):
         self.app.conf = AttributeDict({

+ 49 - 0
celery/utils/__init__.py

@@ -113,6 +113,55 @@ def deprecated(deprecation=None, removal=None,
     return _inner
 
 
+def deprecated_property(deprecation=None, removal=None,
+                        alternative=None, description=None):
+    def _inner(fun):
+        return _deprecated_property(
+            fun, deprecation=deprecation, removal=removal,
+            alternative=alternative, description=description or fun.__name__)
+    return _inner
+
+
+class _deprecated_property(object):
+
+    def __init__(self, fget=None, fset=None, fdel=None, doc=None, **depreinfo):
+        self.__get = fget
+        self.__set = fset
+        self.__del = fdel
+        self.__name__, self.__module__, self.__doc__ = (
+            fget.__name__, fget.__module__, fget.__doc__,
+        )
+        self.depreinfo = depreinfo
+
+    def __get__(self, obj, type=None):
+        if obj is None:
+            return self
+        warn_deprecated(**self.depreinfo)
+        return self.__get(obj)
+
+    def __set__(self, obj, value):
+        if obj is None:
+            return self
+        if self.__set is None:
+            raise AttributeError('cannot set attribute')
+        warn_deprecated(**self.depreinfo)
+        self.__set(obj, value)
+
+    def __delete__(self, obj):
+        if obj is None:
+            return self
+        if self.__del is None:
+            raise AttributeError('cannot delete attribute')
+        warn_deprecated(**self.depreinfo)
+        self.__del(obj)
+
+    def setter(self, fset):
+        return self.__class__(self.__get, fset, self.__del, **self.depreinfo)
+
+    def deleter(self, fdel):
+        return self.__class__(self.__get, self.__set, fdel, **self.depreinfo)
+
+
 def lpmerge(L, R):
     """In place left precedent dictionary merge.