فهرست منبع

99% coverage for celery.utils.datastructures

Ask Solem 12 سال پیش
والد
کامیت
b21a002fea
2فایلهای تغییر یافته به همراه167 افزوده شده و 37 حذف شده
  1. 122 4
      celery/tests/utils/test_datastructures.py
  2. 45 33
      celery/utils/datastructures.py

+ 122 - 4
celery/tests/utils/test_datastructures.py

@@ -1,6 +1,10 @@
 from __future__ import absolute_import
 
+import pickle
+
 from billiard.einfo import ExceptionInfo
+from mock import Mock, patch
+from time import time
 
 from celery.five import items
 from celery.utils.datastructures import (
@@ -20,7 +24,7 @@ class Object(object):
 
 class test_DictAttribute(Case):
 
-    def test_get_set(self):
+    def test_get_set_keys_values_items(self):
         x = DictAttribute(Object())
         x['foo'] = 'The quick brown fox'
         self.assertEqual(x['foo'], 'The quick brown fox')
@@ -29,6 +33,14 @@ class test_DictAttribute(Case):
         self.assertIsNone(x.get('bar'))
         with self.assertRaises(KeyError):
             x['bar']
+        x.foo = 'The quick yellow fox'
+        self.assertEqual(x['foo'], 'The quick yellow fox')
+        self.assertIn(
+            ('foo', 'The quick yellow fox'),
+            list(x.items()),
+        )
+        self.assertIn('foo', list(x.keys()))
+        self.assertIn('The quick yellow fox', x.values())
 
     def test_setdefault(self):
         x = DictAttribute(Object())
@@ -94,6 +106,37 @@ class test_ConfigurationView(Case):
             list(self.view.values()),
             list(expected.values()),
         )
+        self.assertIn('changed_key', list(self.view.keys()))
+        self.assertIn(2, list(self.view.values()))
+        self.assertIn(('both', 2), list(self.view.items()))
+
+    def test_add_defaults_dict(self):
+        defaults = {'foo': 10}
+        self.view.add_defaults(defaults)
+        self.assertEqual(self.view.foo, 10)
+
+    def test_add_defaults_object(self):
+        defaults = Object()
+        defaults.foo = 10
+        self.view.add_defaults(defaults)
+        self.assertEqual(self.view.foo, 10)
+
+    def test_clear(self):
+        self.view.clear()
+        self.assertEqual(self.view.both, 1)
+        self.assertNotIn('changed_key', self.view)
+
+    def test_bool(self):
+        self.assertTrue(bool(self.view))
+        self.view._order[:] = []
+        self.assertFalse(bool(self.view))
+
+    def test_len(self):
+        self.assertEqual(len(self.view), 3)
+        self.view.KEY = 33
+        self.assertEqual(len(self.view), 4)
+        self.view.clear()
+        self.assertEqual(len(self.view), 2)
 
     def test_isa_mapping(self):
         from collections import Mapping
@@ -136,14 +179,61 @@ class test_LimitedSet(Case):
             self.assertIn(n, s)
         self.assertNotIn('foo', s)
 
-    def test_iter(self):
+    def test_purge(self):
+        s = LimitedSet(maxlen=None)
+        [s.add(i) for i in range(10)]
+        s.maxlen = 2
+        s.purge(1)
+        self.assertEqual(len(s), 9)
+        s.purge(None)
+        self.assertEqual(len(s), 2)
+
+        # expired
+        s = LimitedSet(maxlen=None, expires=1)
+        [s.add(i) for i in range(10)]
+        s.maxlen = 2
+        s.purge(1, now=lambda: time() + 100)
+        self.assertEqual(len(s), 9)
+        s.purge(None, now=lambda: time() + 100)
+        self.assertEqual(len(s), 2)
+
+        # not expired
+        s = LimitedSet(maxlen=None, expires=1)
+        [s.add(i) for i in range(10)]
+        s.maxlen = 2
+        s.purge(1, now=lambda: time() - 100)
+        self.assertEqual(len(s), 10)
+        s.purge(None, now=lambda: time() - 100)
+        self.assertEqual(len(s), 10)
+
+        s = LimitedSet(maxlen=None)
+        [s.add(i) for i in range(10)]
+        s.maxlen = 2
+        with patch('celery.utils.datastructures.heappop') as hp:
+            hp.side_effect = IndexError()
+            s.purge()
+            hp.assert_called_with(s._heap)
+        with patch('celery.utils.datastructures.heappop') as hp:
+            s._data = dict((i * 2, i * 2) for i in range(10))
+            s.purge()
+            self.assertEqual(hp.call_count, 10)
+
+    def test_pickleable(self):
         s = LimitedSet(maxlen=2)
-        items = 'foo', 'bar'
+        s.add('foo')
+        s.add('bar')
+        self.assertEqual(pickle.loads(pickle.dumps(s)), s)
+
+    def test_iter(self):
+        s = LimitedSet(maxlen=3)
+        items = ['foo', 'bar', 'baz', 'xaz']
         for item in items:
             s.add(item)
         l = list(iter(s))
-        for item in items:
+        for item in items[1:]:
             self.assertIn(item, l)
+        self.assertNotIn('foo', l)
+        self.assertListEqual(l, items[1:], 'order by insertion time')
 
     def test_repr(self):
         s = LimitedSet(maxlen=2)
@@ -152,6 +242,13 @@ class test_LimitedSet(Case):
             s.add(item)
         self.assertIn('LimitedSet(', repr(s))
 
+    def test_discard(self):
+        s = LimitedSet(maxlen=2)
+        s.add('foo')
+        s.discard('foo')
+        self.assertNotIn('foo', s)
+        s.discard('foo')
+
     def test_clear(self):
         s = LimitedSet(maxlen=2)
         s.add('foo')
@@ -220,12 +317,33 @@ class test_DependencyGraph(Case):
             ['C', 'D'],
         )
 
+    def test_connect(self):
+        x, y = self.graph1(), self.graph1()
+        x.connect(y)
+
+    def test_valency_of_when_missing(self):
+        x = self.graph1()
+        self.assertEqual(x.valency_of('foobarbaz'), 0)
+
+    def test_format(self):
+        x = self.graph1()
+        x.formatter = Mock()
+        obj = Mock()
+        self.assertTrue(x.format(obj))
+        x.formatter.assert_called_with(obj)
+        x.formatter = None
+        self.assertIs(x.format(obj), obj)
+
     def test_items(self):
         self.assertDictEqual(
             dict(items(self.graph1())),
             {'A': [], 'B': [], 'C': ['A'], 'D': ['C', 'B']},
         )
 
+    def test_repr_node(self):
+        x = self.graph1()
+        self.assertTrue(x.repr_node('fasdswewqewq'))
+
     def test_to_dot(self):
         s = WhateverIO()
         self.graph1().to_dot(s)

+ 45 - 33
celery/utils/datastructures.py

@@ -401,6 +401,11 @@ class DictAttribute(object):
             yield key, getattr(self.obj, key)
     iteritems = _iterate_items
 
+    def _iterate_values(self):
+        for key in self._iterate_keys():
+            yield getattr(self.obj, key)
+    itervalues = _iterate_values
+
     if sys.version_info[0] == 3:  # pragma: no cover
         items = _iterate_items
         keys = _iterate_keys
@@ -411,6 +416,9 @@ class DictAttribute(object):
 
         def items(self):
             return list(self._iterate_items())
+
+        def values(self):
+            return list(self._iterate_values())
 MutableMapping.register(DictAttribute)
 
 
@@ -479,6 +487,7 @@ class ConfigurationView(AttributeDictMixin):
 
     def __bool__(self):
         return any(self._order)
+    __nonzero__  = __bool__  # Py2
 
     def __repr__(self):
         return repr(dict(items(self)))
@@ -548,26 +557,15 @@ class LimitedSet(object):
         self.__len__ = self._data.__len__
         self.__contains__ = self._data.__contains__
 
-    def __iter__(self):
-        return iter(self._data)
-
-    def __len__(self):
-        return len(self._data)
-
-    def __contains__(self, key):
-        return key in self._data
-
-    def add(self, value):
+    def add(self, value, now=time.time):
         """Add a new member."""
-        self.purge(1)
-        now = time.time()
-        self._data[value] = now
-        heappush(self._heap, (now, value))
-
-    def __reduce__(self):
-        return self.__class__, (
-            self.maxlen, self.expires, self._data, self._heap,
-        )
+        # offset is there to modify the length of the list,
+        # this way we can expire an item before inserting the value,
+        # and it will end up in correct order.
+        self.purge(1, offset=1)
+        inserted = now()
+        self._data[value] = inserted
+        heappush(self._heap, (inserted, value))
 
     def clear(self):
         """Remove all members"""
@@ -587,24 +585,27 @@ class LimitedSet(object):
         self._data.pop(value, None)
     pop_value = discard  # XXX compat
 
-    def _expire_item(self):
-        """Hunt down and remove an expired item."""
-        self.purge(1)
-
-    def purge(self, limit=None):
+    def purge(self, limit=None, offset=0, now=time.time):
+        """Purge expired items."""
         H, maxlen = self._heap, self.maxlen
         if not maxlen:
             return
+
+        # If the data/heap gets corrupted and limit is None
+        # this will go into an infinite loop, so limit must
+        # have a value to guard the loop.
+        limit = len(self) + offset if limit is None else limit
+
         i = 0
-        while len(self) >= maxlen:
-            if limit and i > limit:
+        while len(self) + offset > maxlen:
+            if i >= limit:
                 break
             try:
                 item = heappop(H)
             except IndexError:
                 break
             if self.expires:
-                if time.time() < item[0] + self.expires:
+                if now() < item[0] + self.expires:
                     heappush(H, item)
                     break
             try:
@@ -625,11 +626,22 @@ class LimitedSet(object):
     def as_dict(self):
         return self._data
 
+    def __eq__(self, other):
+        return self._heap == other._heap
+
     def __repr__(self):
-        return 'LimitedSet(%s)' % (repr(list(self._data))[:100], )
+        return 'LimitedSet({0})'.format(len(self))
+
+    def __iter__(self):
+        return (item[1] for item in self._heap)
 
-    @property
-    def first(self):
-        """Get the oldest member."""
-        return self._heap[0][1]
-MutableSet.register(LimitedSet)
+    def __len__(self):
+        return len(self._heap)
+
+    def __contains__(self, key):
+        return key in self._data
+
+    def __reduce__(self):
+        return self.__class__, (
+            self.maxlen, self.expires, self._data, self._heap,
+        )