Parcourir la source

Utils: ConfigurationView is now a ChainMap

Ask Solem il y a 8 ans
Parent
commit
e3214fc0e0
3 fichiers modifiés avec 144 ajouts et 79 suppressions
  1. 1 1
      celery/app/utils.py
  2. 142 77
      celery/datastructures.py
  3. 1 1
      celery/tests/utils/test_datastructures.py

+ 1 - 1
celery/app/utils.py

@@ -117,7 +117,7 @@ class Settings(ConfigurationView):
     def without_defaults(self):
         """Return the current configuration, but without defaults."""
         # the last stash is the default settings, so just skip that
-        return Settings({}, self._order[:-1])
+        return Settings({}, self.maps[:-1])
 
     def value_set_for(self, key):
         return key in self.without_defaults()

+ 142 - 77
celery/datastructures.py

@@ -448,78 +448,80 @@ class DictAttribute(object):
 MutableMapping.register(DictAttribute)
 
 
-@python_2_unicode_compatible
-class ConfigurationView(AttributeDictMixin):
-    """A view over an applications configuration dictionaries.
-
-    Custom (but older) version of :class:`collections.ChainMap`.
-
-    If the key does not exist in ``changes``, the ``defaults``
-    dictionaries are consulted.
-
-    :param changes:  Dict containing changes to the configuration.
-    :param defaults: List of dictionaries containing the default
-                     configuration.
+class ChainMap(MutableMapping):
 
-    """
     key_t = None
     changes = None
     defaults = None
-    _order = None
+    maps = None
 
-    def __init__(self, changes, defaults=None, key_t=None, prefix=None):
-        defaults = [] if defaults is None else defaults
+    def __init__(self, *maps, **kwargs):
+        maps = list(maps or [{}])
         self.__dict__.update(
-            changes=changes,
-            defaults=defaults,
-            key_t=key_t,
-            _order=[changes] + defaults,
-            prefix=prefix.rstrip('_') + '_' if prefix else prefix,
+            key_t=kwargs.get('key_t'),
+            maps=maps,
+            changes=maps[0],
+            defaults=maps[1:],
         )
 
-    def _to_keys(self, key):
-        prefix = self.prefix
-        if prefix:
-            pkey = prefix + key if not key.startswith(prefix) else key
-            return match_case(pkey, prefix), self._key(key)
-        return self._key(key),
-
-    def _key(self, key):
-        return self.key_t(key) if self.key_t is not None else key
-
     def add_defaults(self, d):
         d = force_mapping(d)
         self.defaults.insert(0, d)
-        self._order.insert(1, d)
+        self.maps.insert(1, d)
 
-    def __getitem__(self, key):
-        keys = self._to_keys(key)
-        for k in keys:
-            for d in self._order:
-                try:
-                    return d[k]
-                except KeyError:
-                    pass
-        if len(keys) > 1:
+    def pop(self, key, *default):
+        try:
+            return self.maps[0].pop(key, *default)
+        except KeyError:
             raise KeyError(
-                'Key not found: {0!r} (with prefix: {0!r})'.format(*keys))
+                'Key not found in the first mapping: {!r}'.format(key))
+
+    def __missing__(self, key):
         raise KeyError(key)
 
+    def _key(self, key):
+        return self.key_t(key) if self.key_t is not None else key
+
+    def __getitem__(self, key):
+        _key = self._key(key)
+        for mapping in self.maps:
+            try:
+                return mapping[_key]
+            except KeyError:
+                pass
+        return self.__missing__(key)
+
     def __setitem__(self, key, value):
         self.changes[self._key(key)] = value
 
-    def first(self, *keys):
-        return first(None, (self.get(key) for key in keys))
+    def __delitem__(self, key):
+        try:
+            del self.changes[self._key(key)]
+        except KeyError:
+            raise KeyError('Key not found in first mapping: {0!r}'.format(key))
+
+    def clear(self):
+        self.changes.clear()
 
     def get(self, key, default=None):
         try:
-            return self[key]
+            return self[self._key(key)]
         except KeyError:
             return default
 
-    def clear(self):
-        """Remove all changes, but keep defaults."""
-        self.changes.clear()
+    def __len__(self):
+        return len(set().union(*self.maps))
+
+    def __iter__(self):
+        return self._iterate_keys()
+
+    def __contains__(self, key):
+        key = self._key(key)
+        return any(key in m for m in self.maps)
+
+    def __bool__(self):
+        return any(self.maps)
+    __nonzero__ = __bool__  # Py2
 
     def setdefault(self, key, default):
         key = self._key(key)
@@ -529,40 +531,25 @@ class ConfigurationView(AttributeDictMixin):
     def update(self, *args, **kwargs):
         return self.changes.update(*args, **kwargs)
 
-    def __contains__(self, key):
-        keys = self._to_keys(key)
-        return any(any(k in m for k in keys) for m in self._order)
-
-    def __bool__(self):
-        return any(self._order)
-    __nonzero__ = __bool__  # Py2
-
     def __repr__(self):
-        return repr(dict(items(self)))
+        return '{0.__class__.__name__}({1})'.format(
+            self, ', '.join(map(repr, self.maps)))
 
-    def __iter__(self):
-        return self._iterate_keys()
+    @classmethod
+    def fromkeys(cls, iterable, *args):
+        """Create a ChainMap with a single dict created from the iterable."""
+        return cls(dict.fromkeys(iterable, *args))
 
-    def __len__(self):
-        # The logic for iterating keys includes uniq(),
-        # so to be safe we count by explicitly iterating
-        return len(set().union(*self._order))
+    def copy(self):
+        """New ChainMap or subclass with a new copy of maps[0] and
+        refs to maps[1:]."""
+        return self.__class__(self.maps[0].copy(), *self.maps[1:])
+    __copy__ = copy  # Py2
 
     def _iter(self, op):
         # defaults must be first in the stream, so values in
-        # changes takes precedence.
-        return chain(*[op(d) for d in reversed(self._order)])
-
-    def swap_with(self, other):
-        changes = other.__dict__['changes']
-        defaults = other.__dict__['defaults']
-        self.__dict__.update(
-            changes=changes,
-            defaults=defaults,
-            key_t=other.__dict__['key_t'],
-            prefix=other.__dict__['prefix'],
-            _order=[changes] + defaults
-        )
+        # changes take precedence.
+        return chain(*[op(d) for d in reversed(self.maps)])
 
     def _iterate_keys(self):
         return uniq(self._iter(lambda d: d.keys()))
@@ -590,7 +577,85 @@ class ConfigurationView(AttributeDictMixin):
 
         def values(self):
             return list(self._iterate_values())
-MutableMapping.register(ConfigurationView)
+
+
+@python_2_unicode_compatible
+class ConfigurationView(ChainMap, AttributeDictMixin):
+    """A view over an applications configuration dictionaries.
+
+    Custom (but older) version of :class:`collections.ChainMap`.
+
+    If the key does not exist in ``changes``, the ``defaults``
+    dictionaries are consulted.
+
+    :param changes:  Dict containing changes to the configuration.
+    :param defaults: List of dictionaries containing the default
+                     configuration.
+
+    """
+
+    def __init__(self, changes, defaults=None, key_t=None, prefix=None):
+        defaults = [] if defaults is None else defaults
+        super(ConfigurationView, self).__init__(
+            changes, *defaults, **{'key_t': key_t})
+        self.__dict__.update(
+            prefix=prefix.rstrip('_') + '_' if prefix else prefix,
+        )
+
+    def _to_keys(self, key):
+        prefix = self.prefix
+        if prefix:
+            pkey = prefix + key if not key.startswith(prefix) else key
+            return match_case(pkey, prefix), key
+        return key,
+
+    def __getitem__(self, key):
+        keys = self._to_keys(key)
+        getitem = super(ConfigurationView, self).__getitem__
+        for k in keys:
+            try:
+                return getitem(k)
+            except KeyError:
+                pass
+        try:
+            # support subclasses implementing __missing__
+            return self.__missing__(key)
+        except KeyError:
+            if len(keys) > 1:
+                raise KeyError(
+                    'Key not found: {0!r} (with prefix: {0!r})'.format(*keys))
+            raise
+
+    def __setitem__(self, key, value):
+        self.changes[self._key(key)] = value
+
+    def first(self, *keys):
+        return first(None, (self.get(key) for key in keys))
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def clear(self):
+        """Remove all changes, but keep defaults."""
+        self.changes.clear()
+
+    def __contains__(self, key):
+        keys = self._to_keys(key)
+        return any(any(k in m for k in keys) for m in self.maps)
+
+    def swap_with(self, other):
+        changes = other.__dict__['changes']
+        defaults = other.__dict__['defaults']
+        self.__dict__.update(
+            changes=changes,
+            defaults=defaults,
+            key_t=other.__dict__['key_t'],
+            prefix=other.__dict__['prefix'],
+            maps=[changes] + defaults
+        )
 
 
 @python_2_unicode_compatible

+ 1 - 1
celery/tests/utils/test_datastructures.py

@@ -131,7 +131,7 @@ class test_ConfigurationView(Case):
 
     def test_bool(self):
         self.assertTrue(bool(self.view))
-        self.view._order[:] = []
+        self.view.maps[:] = []
         self.assertFalse(bool(self.view))
 
     def test_len(self):