test_datastructures.py 12 KB


from __future__ import absolute_import

import pickle
import sys

try:
    from collections.abc import Mapping
except ImportError:  # Python 2: the ABCs still live in ``collections``.
    from collections import Mapping
from itertools import count
from time import time

from billiard.einfo import ExceptionInfo

from celery.datastructures import (
    LimitedSet,
    AttributeDict,
    DictAttribute,
    ConfigurationView,
    DependencyGraph,
)
from celery.five import items
from celery.tests.case import Case, Mock, WhateverIO, SkipTest, patch
  17. class Object(object):
  18. pass
  19. class test_DictAttribute(Case):
  20. def test_get_set_keys_values_items(self):
  21. x = DictAttribute(Object())
  22. x['foo'] = 'The quick brown fox'
  23. self.assertEqual(x['foo'], 'The quick brown fox')
  24. self.assertEqual(x['foo'], x.obj.foo)
  25. self.assertEqual(x.get('foo'), 'The quick brown fox')
  26. self.assertIsNone(x.get('bar'))
  27. with self.assertRaises(KeyError):
  28. x['bar']
  29. x.foo = 'The quick yellow fox'
  30. self.assertEqual(x['foo'], 'The quick yellow fox')
  31. self.assertIn(
  32. ('foo', 'The quick yellow fox'),
  33. list(x.items()),
  34. )
  35. self.assertIn('foo', list(x.keys()))
  36. self.assertIn('The quick yellow fox', list(x.values()))
  37. def test_setdefault(self):
  38. x = DictAttribute(Object())
  39. x.setdefault('foo', 'NEW')
  40. self.assertEqual(x['foo'], 'NEW')
  41. x.setdefault('foo', 'XYZ')
  42. self.assertEqual(x['foo'], 'NEW')
  43. def test_contains(self):
  44. x = DictAttribute(Object())
  45. x['foo'] = 1
  46. self.assertIn('foo', x)
  47. self.assertNotIn('bar', x)
  48. def test_items(self):
  49. obj = Object()
  50. obj.attr1 = 1
  51. x = DictAttribute(obj)
  52. x['attr2'] = 2
  53. self.assertEqual(x['attr1'], 1)
  54. self.assertEqual(x['attr2'], 2)
  55. class test_ConfigurationView(Case):
  56. def setUp(self):
  57. self.view = ConfigurationView({'changed_key': 1,
  58. 'both': 2},
  59. [{'default_key': 1,
  60. 'both': 1}])
  61. def test_setdefault(self):
  62. self.view.setdefault('both', 36)
  63. self.assertEqual(self.view['both'], 2)
  64. self.view.setdefault('new', 36)
  65. self.assertEqual(self.view['new'], 36)
  66. def test_get(self):
  67. self.assertEqual(self.view.get('both'), 2)
  68. sp = object()
  69. self.assertIs(self.view.get('nonexisting', sp), sp)
  70. def test_update(self):
  71. changes = dict(self.view.changes)
  72. self.view.update(a=1, b=2, c=3)
  73. self.assertDictEqual(self.view.changes,
  74. dict(changes, a=1, b=2, c=3))
  75. def test_contains(self):
  76. self.assertIn('changed_key', self.view)
  77. self.assertIn('default_key', self.view)
  78. self.assertNotIn('new', self.view)
  79. def test_repr(self):
  80. self.assertIn('changed_key', repr(self.view))
  81. self.assertIn('default_key', repr(self.view))
  82. def test_iter(self):
  83. expected = {'changed_key': 1,
  84. 'default_key': 1,
  85. 'both': 2}
  86. self.assertDictEqual(dict(items(self.view)), expected)
  87. self.assertItemsEqual(list(iter(self.view)),
  88. list(expected.keys()))
  89. self.assertItemsEqual(list(self.view.keys()), list(expected.keys()))
  90. self.assertItemsEqual(
  91. list(self.view.values()),
  92. list(expected.values()),
  93. )
  94. self.assertIn('changed_key', list(self.view.keys()))
  95. self.assertIn(2, list(self.view.values()))
  96. self.assertIn(('both', 2), list(self.view.items()))
  97. def test_add_defaults_dict(self):
  98. defaults = {'foo': 10}
  99. self.view.add_defaults(defaults)
  100. self.assertEqual(self.view.foo, 10)
  101. def test_add_defaults_object(self):
  102. defaults = Object()
  103. defaults.foo = 10
  104. self.view.add_defaults(defaults)
  105. self.assertEqual(self.view.foo, 10)
  106. def test_clear(self):
  107. self.view.clear()
  108. self.assertEqual(self.view.both, 1)
  109. self.assertNotIn('changed_key', self.view)
  110. def test_bool(self):
  111. self.assertTrue(bool(self.view))
  112. self.view._order[:] = []
  113. self.assertFalse(bool(self.view))
  114. def test_len(self):
  115. self.assertEqual(len(self.view), 3)
  116. self.view.KEY = 33
  117. self.assertEqual(len(self.view), 4)
  118. self.view.clear()
  119. self.assertEqual(len(self.view), 2)
  120. def test_isa_mapping(self):
  121. from collections import Mapping
  122. self.assertTrue(issubclass(ConfigurationView, Mapping))
  123. def test_isa_mutable_mapping(self):
  124. from collections import MutableMapping
  125. self.assertTrue(issubclass(ConfigurationView, MutableMapping))
  126. class test_ExceptionInfo(Case):
  127. def test_exception_info(self):
  128. try:
  129. raise LookupError('The quick brown fox jumps...')
  130. except Exception:
  131. einfo = ExceptionInfo()
  132. self.assertEqual(str(einfo), einfo.traceback)
  133. self.assertIsInstance(einfo.exception, LookupError)
  134. self.assertTupleEqual(
  135. einfo.exception.args, ('The quick brown fox jumps...',),
  136. )
  137. self.assertTrue(einfo.traceback)
  138. r = repr(einfo)
  139. self.assertTrue(r)
  140. class test_LimitedSet(Case):
  141. def setUp(self):
  142. if sys.platform == 'win32':
  143. raise SkipTest('Not working on Windows')
  144. def test_add(self):
  145. if sys.platform == 'win32':
  146. raise SkipTest('Not working properly on Windows')
  147. s = LimitedSet(maxlen=2)
  148. s.add('foo')
  149. s.add('bar')
  150. for n in 'foo', 'bar':
  151. self.assertIn(n, s)
  152. s.add('baz')
  153. for n in 'bar', 'baz':
  154. self.assertIn(n, s)
  155. self.assertNotIn('foo', s)
  156. s = LimitedSet(maxlen=10)
  157. for i in range(150):
  158. s.add(i)
  159. self.assertLessEqual(len(s), 10)
  160. # make sure heap is not leaking:
  161. self.assertLessEqual(
  162. len(s._heap),
  163. len(s) * (100. + s.max_heap_percent_overload) / 100,
  164. )
  165. def test_purge(self):
  166. # purge now enforces rules
  167. # cant purge(1) now. but .purge(now=...) still works
  168. s = LimitedSet(maxlen=10)
  169. [s.add(i) for i in range(10)]
  170. s.maxlen = 2
  171. s.purge()
  172. self.assertEqual(len(s), 2)
  173. # expired
  174. s = LimitedSet(maxlen=10, expires=1)
  175. [s.add(i) for i in range(10)]
  176. s.maxlen = 2
  177. s.purge(now=time() + 100)
  178. self.assertEqual(len(s), 0)
  179. # not expired
  180. s = LimitedSet(maxlen=None, expires=1)
  181. [s.add(i) for i in range(10)]
  182. s.maxlen = 2
  183. s.purge(now=lambda: time() - 100)
  184. self.assertEqual(len(s), 2)
  185. # expired -> minsize
  186. s = LimitedSet(maxlen=10, minlen=10, expires=1)
  187. [s.add(i) for i in range(20)]
  188. s.minlen = 3
  189. s.purge(now=time() + 3)
  190. self.assertEqual(s.minlen, len(s))
  191. self.assertLessEqual(
  192. len(s._heap),
  193. s.maxlen * (100. + s.max_heap_percent_overload) / 100,
  194. )
  195. def test_pickleable(self):
  196. s = LimitedSet(maxlen=2)
  197. s.add('foo')
  198. s.add('bar')
  199. self.assertEqual(pickle.loads(pickle.dumps(s)), s)
  200. def test_iter(self):
  201. if sys.platform == 'win32':
  202. raise SkipTest('Not working on Windows')
  203. s = LimitedSet(maxlen=3)
  204. items = ['foo', 'bar', 'baz', 'xaz']
  205. for item in items:
  206. s.add(item)
  207. l = list(iter(s))
  208. for item in items[1:]:
  209. self.assertIn(item, l)
  210. self.assertNotIn('foo', l)
  211. self.assertListEqual(l, items[1:], 'order by insertion time')
  212. def test_repr(self):
  213. s = LimitedSet(maxlen=2)
  214. items = 'foo', 'bar'
  215. for item in items:
  216. s.add(item)
  217. self.assertIn('LimitedSet(', repr(s))
  218. def test_discard(self):
  219. s = LimitedSet(maxlen=2)
  220. s.add('foo')
  221. s.discard('foo')
  222. self.assertNotIn('foo', s)
  223. self.assertEqual(len(s._data), 0)
  224. s.discard('foo')
  225. def test_clear(self):
  226. s = LimitedSet(maxlen=2)
  227. s.add('foo')
  228. s.add('bar')
  229. self.assertEqual(len(s), 2)
  230. s.clear()
  231. self.assertFalse(s)
  232. def test_update(self):
  233. s1 = LimitedSet(maxlen=2)
  234. s1.add('foo')
  235. s1.add('bar')
  236. s2 = LimitedSet(maxlen=2)
  237. s2.update(s1)
  238. self.assertItemsEqual(list(s2), ['foo', 'bar'])
  239. s2.update(['bla'])
  240. self.assertItemsEqual(list(s2), ['bla', 'bar'])
  241. s2.update(['do', 're'])
  242. self.assertItemsEqual(list(s2), ['do', 're'])
  243. s1 = LimitedSet(maxlen=10, expires=None)
  244. s2 = LimitedSet(maxlen=10, expires=None)
  245. s3 = LimitedSet(maxlen=10, expires=None)
  246. s4 = LimitedSet(maxlen=10, expires=None)
  247. s5 = LimitedSet(maxlen=10, expires=None)
  248. for i in range(12):
  249. s1.add(i)
  250. s2.add(i*i)
  251. s3.update(s1)
  252. s3.update(s2)
  253. s4.update(s1.as_dict())
  254. s4.update(s2.as_dict())
  255. s5.update(s1._data) # revoke is using this
  256. s5.update(s2._data)
  257. self.assertEqual(s3, s4)
  258. self.assertEqual(s3, s5)
  259. s2.update(s4)
  260. s4.update(s2)
  261. self.assertEqual(s2, s4)
  262. def test_iterable_and_ordering(self):
  263. s = LimitedSet(maxlen=35, expires=None)
  264. # we use a custom clock here, as time.time() does not have enough
  265. # precision when called quickly (can return the same value twice).
  266. clock = count(1)
  267. for i in reversed(range(15)):
  268. s.add(i, now=next(clock))
  269. j = 40
  270. for i in s:
  271. self.assertLess(i, j) # each item is smaller and smaller
  272. j = i
  273. self.assertEqual(i, 0) # last item is zero
  274. def test_pop_and_ordering_again(self):
  275. s = LimitedSet(maxlen=5)
  276. for i in range(10):
  277. s.add(i)
  278. j = -1
  279. for _ in range(5):
  280. i = s.pop()
  281. self.assertLess(j, i)
  282. i = s.pop()
  283. self.assertEqual(i, None)
  284. def test_as_dict(self):
  285. s = LimitedSet(maxlen=2)
  286. s.add('foo')
  287. self.assertIsInstance(s.as_dict(), Mapping)
  288. class test_AttributeDict(Case):
  289. def test_getattr__setattr(self):
  290. x = AttributeDict({'foo': 'bar'})
  291. self.assertEqual(x['foo'], 'bar')
  292. with self.assertRaises(AttributeError):
  293. x.bar
  294. x.bar = 'foo'
  295. self.assertEqual(x['bar'], 'foo')
  296. class test_DependencyGraph(Case):
  297. def graph1(self):
  298. return DependencyGraph([
  299. ('A', []),
  300. ('B', []),
  301. ('C', ['A']),
  302. ('D', ['C', 'B']),
  303. ])
  304. def test_repr(self):
  305. self.assertTrue(repr(self.graph1()))
  306. def test_topsort(self):
  307. order = self.graph1().topsort()
  308. # C must start before D
  309. self.assertLess(order.index('C'), order.index('D'))
  310. # and B must start before D
  311. self.assertLess(order.index('B'), order.index('D'))
  312. # and A must start before C
  313. self.assertLess(order.index('A'), order.index('C'))
  314. def test_edges(self):
  315. self.assertItemsEqual(
  316. list(self.graph1().edges()),
  317. ['C', 'D'],
  318. )
  319. def test_connect(self):
  320. x, y = self.graph1(), self.graph1()
  321. x.connect(y)
  322. def test_valency_of_when_missing(self):
  323. x = self.graph1()
  324. self.assertEqual(x.valency_of('foobarbaz'), 0)
  325. def test_format(self):
  326. x = self.graph1()
  327. x.formatter = Mock()
  328. obj = Mock()
  329. self.assertTrue(x.format(obj))
  330. x.formatter.assert_called_with(obj)
  331. x.formatter = None
  332. self.assertIs(x.format(obj), obj)
  333. def test_items(self):
  334. self.assertDictEqual(
  335. dict(items(self.graph1())),
  336. {'A': [], 'B': [], 'C': ['A'], 'D': ['C', 'B']},
  337. )
  338. def test_repr_node(self):
  339. x = self.graph1()
  340. self.assertTrue(x.repr_node('fasdswewqewq'))
  341. def test_to_dot(self):
  342. s = WhateverIO()
  343. self.graph1().to_dot(s)
  344. self.assertTrue(s.getvalue())