Jelajahi Sumber

More coverage celery.datastructures

Ask Solem 13 tahun lalu
induk
melakukan
8bdbf2ec7a

+ 17 - 14
celery/datastructures.py

@@ -22,6 +22,7 @@ from threading import RLock
 
 from kombu.utils.limits import TokenBucket  # noqa
 
+from .utils import uniq
 from .utils.compat import UserDict, OrderedDict
 
 
@@ -262,7 +263,7 @@ class DictAttribute(object):
         return vars(self.obj).iteritems()
     iteritems = _iterate_items
 
-    if sys.version_info >= (3, 0):
+    if sys.version_info >= (3, 0):  # pragma: no cover
         items = _iterate_items
     else:
 
@@ -333,15 +334,15 @@ class ConfigurationView(AttributeDictMixin):
         return chain(*[op(d) for d in reversed(self._order)])
 
     def _iterate_keys(self):
-        return self._iter(lambda d: d.iterkeys())
+        return uniq(self._iter(lambda d: d.iterkeys()))
     iterkeys = _iterate_keys
 
     def _iterate_items(self):
-        return self._iter(lambda d: d.iteritems())
+        return ((key, self[key]) for key in self)
     iteritems = _iterate_items
 
     def _iterate_values(self):
-        return self._iter(lambda d: d.itervalues())
+        return (self[key] for key in self)
     itervalues = _iterate_values
 
     def keys(self):
@@ -370,12 +371,13 @@ class _Frame(object):
             "__name__": frame.f_globals.get("__name__"),
             "__loader__": frame.f_globals.get("__loader__"),
         }
+        self.f_locals = fl = {}
+        try:
+            fl["__traceback_hide__"] = frame.f_locals["__traceback_hide__"]
+        except KeyError:
+            pass
         self.f_code = self.Code(frame.f_code)
-        self.f_locals = {}
-        if '__traceback_hide__' in frame.f_locals:
-            self.f_locals['__traceback_hide__'] = frame.f_locals['__traceback_hide__']
         self.f_lineno = frame.f_lineno
-        self.f_code = _Code(code=frame.f_code)
 
 
 class Traceback(object):
@@ -477,7 +479,8 @@ class LimitedSet(object):
         if isinstance(other, self.__class__):
             self._data.update(other._data)
         else:
-            self._data.update(other)
+            for obj in other:
+                self.add(obj)
 
     def as_dict(self):
         return self._data
@@ -519,7 +522,7 @@ class LRUCache(UserDict):
     def __getitem__(self, key):
         with self.mutex:
             value = self[key] = self.data.pop(key)
-            return value
+        return value
 
     def keys(self):
         # userdict.keys in py3k calls __getitem__
@@ -542,7 +545,7 @@ class LRUCache(UserDict):
         return self.data.iterkeys()
 
     def _iterate_items(self):
-        for k in self.data:
+        for k in self:
             try:
                 yield (k, self.data[k])
             except KeyError:
@@ -550,10 +553,10 @@ class LRUCache(UserDict):
     iteritems = _iterate_items
 
     def _iterate_values(self):
-        for k in self.data:
+        for k in self:
             try:
                 yield self.data[k]
-            except KeyError:
+            except KeyError:  # pragma: no cover
                 pass
     itervalues = _iterate_values
 
@@ -563,4 +566,4 @@ class LRUCache(UserDict):
             # integer as long as it exists and we can cast it
             newval = int(self.data.pop(key)) + delta
             self[key] = str(newval)
-            return newval
+        return newval

+ 1 - 2
celery/tests/test_concurrency/test_concurrency_eventlet.py

@@ -29,8 +29,7 @@ class test_eventlet_patch(EventletCase):
         prev_eventlet = sys.modules.pop("celery.concurrency.eventlet", None)
         os.environ.pop("EVENTLET_NOPATCH")
         try:
-            from celery.concurrency import eventlet
-            self.assertTrue(eventlet)
+            import celery.concurrency.eventlet
             self.assertTrue(monkey_patched)
         finally:
             sys.modules["celery.concurrency.eventlet"] = prev_eventlet

+ 94 - 5
celery/tests/test_utils/test_datastructures.py

@@ -3,11 +3,11 @@ from __future__ import with_statement
 
 import sys
 
-from celery.datastructures import ExceptionInfo, LRUCache
-from celery.datastructures import LimitedSet
-from celery.datastructures import AttributeDict, DictAttribute
-from celery.datastructures import ConfigurationView
+from celery.datastructures import (ExceptionInfo, LRUCache, LimitedSet,
+                                   AttributeDict, DictAttribute,
+                                   ConfigurationView, DependencyGraph)
 from celery.tests.utils import unittest
+from celery.tests.utils import WhateverIO
 
 
 class Object(object):
@@ -37,13 +37,15 @@ class test_DictAttribute(unittest.TestCase):
         self.assertIn("foo", x)
         self.assertNotIn("bar", x)
 
-    def test_iteritems(self):
+    def test_items(self):
         obj = Object()
         obj.attr1 = 1
         x = DictAttribute(obj)
         x["attr2"] = 2
         self.assertDictEqual(dict(x.iteritems()),
                              dict(attr1=1, attr2=2))
+        self.assertDictEqual(dict(x.items()),
+                             dict(attr1=1, attr2=2))
 
 
 class test_ConfigurationView(unittest.TestCase):
@@ -58,6 +60,17 @@ class test_ConfigurationView(unittest.TestCase):
         self.assertEqual(self.view.setdefault("both", 36), 2)
         self.assertEqual(self.view.setdefault("new", 36), 36)
 
+    def test_get(self):
+        self.assertEqual(self.view.get("both"), 2)
+        sp = object()
+        self.assertIs(self.view.get("nonexisting", sp), sp)
+
+    def test_update(self):
+        changes = dict(self.view.changes)
+        self.view.update(a=1, b=2, c=3)
+        self.assertDictEqual(self.view.changes,
+                             dict(changes, a=1, b=2, c=3))
+
     def test_contains(self):
         self.assertIn("changed_key", self.view)
         self.assertIn("default_key", self.view)
@@ -72,6 +85,10 @@ class test_ConfigurationView(unittest.TestCase):
                     "default_key": 1,
                     "both": 2}
         self.assertDictEqual(dict(self.view.items()), expected)
+        self.assertItemsEqual(list(iter(self.view)),
+                              expected.keys())
+        self.assertItemsEqual(self.view.keys(), expected.keys())
+        self.assertItemsEqual(self.view.values(), expected.values())
 
 
 class test_ExceptionInfo(unittest.TestCase):
@@ -123,6 +140,34 @@ class test_LimitedSet(unittest.TestCase):
             s.add(item)
         self.assertIn("LimitedSet(", repr(s))
 
+    def test_clear(self):
+        s = LimitedSet(maxlen=2)
+        s.add("foo")
+        s.add("bar")
+        self.assertEqual(len(s), 2)
+        s.clear()
+        self.assertFalse(s)
+
+    def test_update(self):
+        s1 = LimitedSet(maxlen=2)
+        s1.add("foo")
+        s1.add("bar")
+
+        s2 = LimitedSet(maxlen=2)
+        s2.update(s1)
+        self.assertItemsEqual(list(s2), ["foo", "bar"])
+
+        s2.update(["bla"])
+        self.assertItemsEqual(list(s2), ["bla", "bar"])
+
+        s2.update(["do", "re"])
+        self.assertItemsEqual(list(s2), ["do", "re"])
+
+    def test_as_dict(self):
+        s = LimitedSet(maxlen=2)
+        s.add("foo")
+        self.assertIsInstance(s.as_dict(), dict)
+
 
 class test_LRUCache(unittest.TestCase):
 
@@ -195,6 +240,11 @@ class test_LRUCache(unittest.TestCase):
     def test_safe_to_remove_while_itervalues(self):
         self.assertSafeIter("itervalues")
 
+    def test_items(self):
+        c = LRUCache()
+        c.update(a=1, b=2, c=3)
+        self.assertTrue(c.items())
+
 
 class test_AttributeDict(unittest.TestCase):
 
@@ -205,3 +255,42 @@ class test_AttributeDict(unittest.TestCase):
             x.bar
         x.bar = "foo"
         self.assertEqual(x["bar"], "foo")
+
+
+class test_DependencyGraph(unittest.TestCase):
+
+    def graph1(self):
+        return DependencyGraph([
+            ("A", []),
+            ("B", []),
+            ("C", ["A"]),
+            ("D", ["C", "B"])
+        ])
+
+    def test_repr(self):
+        self.assertTrue(repr(self.graph1()))
+
+    def test_topsort(self):
+        order = self.graph1().topsort()
+        print("ORDER: %r" % (order, ))
+        # C must start before D
+        self.assertLess(order.index("C"), order.index("D"))
+        # and B must start before D
+        self.assertLess(order.index("B"), order.index("D"))
+        # and A must start before C
+        self.assertLess(order.index("A"), order.index("C"))
+
+    def test_edges(self):
+        self.assertListEqual(list(self.graph1().edges()),
+                             ["C", "D"])
+
+    def test_items(self):
+        self.assertDictEqual(dict(self.graph1().items()),
+                {"A": [], "B": [],
+                 "C": ["A"], "D": ["C", "B"]})
+
+    def test_to_dot(self):
+        s = WhateverIO()
+        self.graph1().to_dot(s)
+        self.assertTrue(s.getvalue())
+

+ 8 - 0
celery/utils/__init__.py

@@ -494,3 +494,11 @@ def reprcall(name, args=(), kwargs=(), sep=', '):
     return "%s(%s%s%s)" % (name, sep.join(map(_safe_repr, args)),
                            (args and kwargs) and sep or "",
                            reprkwargs(kwargs, sep))
+
+
+def uniq(it):
+    seen = set()
+    for obj in it:
+        if obj not in seen:
+            yield obj
+            seen.add(obj)