Browse Source

Cosmetics for #3151

Ask Solem 9 years ago
parent
commit
4ea01be5a0
2 changed files with 32 additions and 50 deletions
  1. 26 33
      celery/backends/consul.py
  2. 6 17
      celery/tests/backends/test_consul.py

+ 26 - 33
celery/backends/consul.py

@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 """
     celery.backends.consul
-    ~~~~~~~~~~~~~~~~~~~~~
+    ~~~~~~~~~~~~~~~~~~~~~~
 
     Consul result store backend.
 
     - :class:`ConsulBackend` implements KeyValueStoreBackend to store results
-      the key-value store of Consul.
+      in the key-value store of Consul.
 
 """
 from __future__ import absolute_import, unicode_literals
@@ -22,7 +22,7 @@ try:
 except ImportError:
     consul = None
 
-LOGGER = get_logger(__name__)
+logger = get_logger(__name__)
 
 __all__ = ['ConsulBackend']
 
@@ -32,9 +32,7 @@ the Consul result store backend."""
 
 
 class ConsulBackend(KeyValueStoreBackend):
-    """
-    Consul.io K/V store backend for Celery
-    """
+    """Consul.io K/V store backend for Celery."""
     consul = consul
 
     supports_autoexpire = True
@@ -43,37 +41,31 @@ class ConsulBackend(KeyValueStoreBackend):
     consistency = 'consistent'
     path = None
 
-    def __init__(self, url=None, expires=None, **kwargs):
-        super(ConsulBackend, self).__init__(**kwargs)
+    def __init__(self, *args, **kwargs):
+        super(ConsulBackend, self).__init__(*args, **kwargs)
 
         if self.consul is None:
             raise ImproperlyConfigured(CONSUL_MISSING)
 
-        self.url = url
-        self.expires = self.prepare_expires(expires, int)
+        self._init_from_params(**parse_url(self.url))
 
-        params = parse_url(self.url)
-        self.path = params['virtual_host']
-        LOGGER.debug('Setting on Consul client to connect to %s:%d',
-                     params['hostname'], params['port'])
-        self.client = consul.Consul(host=params['hostname'],
-                                    port=params['port'],
+    def _init_from_params(self, hostname, port, virtual_host, **params):
+        logger.debug('Setting on Consul client to connect to %s:%d',
+                     hostname, port)
+        self.path = virtual_host
+        self.client = consul.Consul(host=hostname, port=port,
                                     consistency=self.consistency)
 
     def _key_to_consul_key(self, key):
         if PY3:
-            key = key.decode('utf-8')
-
-        if self.path is not None:
-            return '{0}/{1}'.format(self.path, key)
-
-        return key
+            key = key.encode('utf-8')
+        return key if self.path is None else '{0}/{1}'.format(self.path, key)
 
     def get(self, key):
-        LOGGER.debug('Trying to fetch key %s from Consul',
-                     self._key_to_consul_key(key))
+        key = self._key_to_consul_key(key)
+        logger.debug('Trying to fetch key %s from Consul', key)
         try:
-            _, data = self.client.kv.get(self._key_to_consul_key(key))
+            _, data = self.client.kv.get(key)
             return data['Value']
         except TypeError:
             pass
@@ -92,25 +84,26 @@ class ConsulBackend(KeyValueStoreBackend):
 
         If the session expires it will remove the key so that results
         can auto expire from the K/V store
+
         """
         session_name = key
-
         if PY3:
             session_name = key.decode('utf-8')
+        key = self._key_to_consul_key(key)
 
-        LOGGER.debug('Trying to create Consul session %s with TTL %d',
+        logger.debug('Trying to create Consul session %s with TTL %d',
                      session_name, self.expires)
         session_id = self.client.session.create(name=session_name,
                                                 behavior='delete',
                                                 ttl=self.expires)
-        LOGGER.debug('Created Consul session %s', session_id)
+        logger.debug('Created Consul session %s', session_id)
 
-        LOGGER.debug('Writing key %s to Consul', self._key_to_consul_key(key))
-        return self.client.kv.put(key=self._key_to_consul_key(key),
+        logger.debug('Writing key %s to Consul', key)
+        return self.client.kv.put(key=key,
                                   value=value,
                                   acquire=session_id)
 
     def delete(self, key):
-        LOGGER.debug('Removing key %s from Consul',
-                     self._key_to_consul_key(key))
-        return self.client.kv.delete(self._key_to_consul_key(key))
+        key = self._key_to_consul_key(key)
+        logger.debug('Removing key %s from Consul', key)
+        return self.client.kv.delete(key)

+ 6 - 17
celery/tests/backends/test_consul.py

@@ -3,19 +3,13 @@ from __future__ import absolute_import, unicode_literals
 from celery.tests.case import AppCase, Mock, skip
 from celery.backends.consul import ConsulBackend
 
-try:
-    import consul
-except ImportError:
-    consul = None
-
 
 @skip.unless_module('consul')
 class test_ConsulBackend(AppCase):
 
     def setup(self):
-        if consul is None:
-            raise SkipTest('python-consul is not installed.')
-        self.backend = ConsulBackend(app=self.app)
+        self.backend = ConsulBackend(
+            app=self.app, url='consul://localhost:800')
 
     def test_supports_autoexpire(self):
         self.assertTrue(self.backend.supports_autoexpire)
@@ -24,14 +18,9 @@ class test_ConsulBackend(AppCase):
         self.assertEqual('consistent', self.backend.consistency)
 
     def test_get(self):
-        c = ConsulBackend(app=self.app)
-        c.client = Mock()
-        c.client.kv = Mock()
-        c.client.kv.get = Mock()
+        c = self.backend
         index = 100
         data = {'Key': 'test-consul-1', 'Value': 'mypayload'}
-        r = (index, data)
-        c.client.kv.get.return_value = r
-        i, d = c.get(data['Key'])
-        self.assertEqual(i, 100)
-        self.assertEqual(d['Key'], data['Key'])
+        self.backend.client = Mock(name='c.client')
+        self.backend.client.kv.get.return_value = (index, data)
+        self.assertEqual(self.backend.get(data['Key']), 'mypayload')