test_datastructures.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. from __future__ import absolute_import
  2. import pickle
  3. import sys
  4. from collections import Mapping
  5. from itertools import count
  6. from billiard.einfo import ExceptionInfo
  7. from time import time
  8. from celery.datastructures import (
  9. LimitedSet,
  10. AttributeDict,
  11. DictAttribute,
  12. ConfigurationView,
  13. DependencyGraph,
  14. )
  15. from celery.five import items
  16. from celery.utils.objects import Bunch
  17. from celery.tests.case import Case, Mock, WhateverIO, SkipTest
  18. class test_DictAttribute(Case):
  19. def test_get_set_keys_values_items(self):
  20. x = DictAttribute(Bunch())
  21. x['foo'] = 'The quick brown fox'
  22. self.assertEqual(x['foo'], 'The quick brown fox')
  23. self.assertEqual(x['foo'], x.obj.foo)
  24. self.assertEqual(x.get('foo'), 'The quick brown fox')
  25. self.assertIsNone(x.get('bar'))
  26. with self.assertRaises(KeyError):
  27. x['bar']
  28. x.foo = 'The quick yellow fox'
  29. self.assertEqual(x['foo'], 'The quick yellow fox')
  30. self.assertIn(
  31. ('foo', 'The quick yellow fox'),
  32. list(x.items()),
  33. )
  34. self.assertIn('foo', list(x.keys()))
  35. self.assertIn('The quick yellow fox', list(x.values()))
  36. def test_setdefault(self):
  37. x = DictAttribute(Bunch())
  38. x.setdefault('foo', 'NEW')
  39. self.assertEqual(x['foo'], 'NEW')
  40. x.setdefault('foo', 'XYZ')
  41. self.assertEqual(x['foo'], 'NEW')
  42. def test_contains(self):
  43. x = DictAttribute(Bunch())
  44. x['foo'] = 1
  45. self.assertIn('foo', x)
  46. self.assertNotIn('bar', x)
  47. def test_items(self):
  48. obj = Bunch(attr1=1)
  49. x = DictAttribute(obj)
  50. x['attr2'] = 2
  51. self.assertEqual(x['attr1'], 1)
  52. self.assertEqual(x['attr2'], 2)
  53. class test_ConfigurationView(Case):
  54. def setUp(self):
  55. self.view = ConfigurationView({'changed_key': 1,
  56. 'both': 2},
  57. [{'default_key': 1,
  58. 'both': 1}])
  59. def test_setdefault(self):
  60. self.view.setdefault('both', 36)
  61. self.assertEqual(self.view['both'], 2)
  62. self.view.setdefault('new', 36)
  63. self.assertEqual(self.view['new'], 36)
  64. def test_get(self):
  65. self.assertEqual(self.view.get('both'), 2)
  66. sp = object()
  67. self.assertIs(self.view.get('nonexisting', sp), sp)
  68. def test_update(self):
  69. changes = dict(self.view.changes)
  70. self.view.update(a=1, b=2, c=3)
  71. self.assertDictEqual(self.view.changes,
  72. dict(changes, a=1, b=2, c=3))
  73. def test_contains(self):
  74. self.assertIn('changed_key', self.view)
  75. self.assertIn('default_key', self.view)
  76. self.assertNotIn('new', self.view)
  77. def test_repr(self):
  78. self.assertIn('changed_key', repr(self.view))
  79. self.assertIn('default_key', repr(self.view))
  80. def test_iter(self):
  81. expected = {'changed_key': 1,
  82. 'default_key': 1,
  83. 'both': 2}
  84. self.assertDictEqual(dict(items(self.view)), expected)
  85. self.assertItemsEqual(list(iter(self.view)),
  86. list(expected.keys()))
  87. self.assertItemsEqual(list(self.view.keys()), list(expected.keys()))
  88. self.assertItemsEqual(
  89. list(self.view.values()),
  90. list(expected.values()),
  91. )
  92. self.assertIn('changed_key', list(self.view.keys()))
  93. self.assertIn(2, list(self.view.values()))
  94. self.assertIn(('both', 2), list(self.view.items()))
  95. def test_add_defaults_dict(self):
  96. defaults = {'foo': 10}
  97. self.view.add_defaults(defaults)
  98. self.assertEqual(self.view.foo, 10)
  99. def test_add_defaults_object(self):
  100. defaults = Bunch(foo=10)
  101. self.view.add_defaults(defaults)
  102. self.assertEqual(self.view.foo, 10)
  103. def test_clear(self):
  104. self.view.clear()
  105. self.assertEqual(self.view.both, 1)
  106. self.assertNotIn('changed_key', self.view)
  107. def test_bool(self):
  108. self.assertTrue(bool(self.view))
  109. self.view._order[:] = []
  110. self.assertFalse(bool(self.view))
  111. def test_len(self):
  112. self.assertEqual(len(self.view), 3)
  113. self.view.KEY = 33
  114. self.assertEqual(len(self.view), 4)
  115. self.view.clear()
  116. self.assertEqual(len(self.view), 2)
  117. def test_isa_mapping(self):
  118. from collections import Mapping
  119. self.assertTrue(issubclass(ConfigurationView, Mapping))
  120. def test_isa_mutable_mapping(self):
  121. from collections import MutableMapping
  122. self.assertTrue(issubclass(ConfigurationView, MutableMapping))
  123. class test_ExceptionInfo(Case):
  124. def test_exception_info(self):
  125. try:
  126. raise LookupError('The quick brown fox jumps...')
  127. except Exception:
  128. einfo = ExceptionInfo()
  129. self.assertEqual(str(einfo), einfo.traceback)
  130. self.assertIsInstance(einfo.exception, LookupError)
  131. self.assertTupleEqual(
  132. einfo.exception.args, ('The quick brown fox jumps...',),
  133. )
  134. self.assertTrue(einfo.traceback)
  135. r = repr(einfo)
  136. self.assertTrue(r)
  137. class test_LimitedSet(Case):
  138. def setUp(self):
  139. if sys.platform == 'win32':
  140. raise SkipTest('Not working on Windows')
  141. def test_add(self):
  142. if sys.platform == 'win32':
  143. raise SkipTest('Not working properly on Windows')
  144. s = LimitedSet(maxlen=2)
  145. s.add('foo')
  146. s.add('bar')
  147. for n in 'foo', 'bar':
  148. self.assertIn(n, s)
  149. s.add('baz')
  150. for n in 'bar', 'baz':
  151. self.assertIn(n, s)
  152. self.assertNotIn('foo', s)
  153. s = LimitedSet(maxlen=10)
  154. for i in range(150):
  155. s.add(i)
  156. self.assertLessEqual(len(s), 10)
  157. # make sure heap is not leaking:
  158. self.assertLessEqual(
  159. len(s._heap),
  160. len(s) * (100. + s.max_heap_percent_overload) / 100,
  161. )
  162. def test_purge(self):
  163. # purge now enforces rules
  164. # cant purge(1) now. but .purge(now=...) still works
  165. s = LimitedSet(maxlen=10)
  166. [s.add(i) for i in range(10)]
  167. s.maxlen = 2
  168. s.purge()
  169. self.assertEqual(len(s), 2)
  170. # expired
  171. s = LimitedSet(maxlen=10, expires=1)
  172. [s.add(i) for i in range(10)]
  173. s.maxlen = 2
  174. s.purge(now=time() + 100)
  175. self.assertEqual(len(s), 0)
  176. # not expired
  177. s = LimitedSet(maxlen=None, expires=1)
  178. [s.add(i) for i in range(10)]
  179. s.maxlen = 2
  180. s.purge(now=lambda: time() - 100)
  181. self.assertEqual(len(s), 2)
  182. # expired -> minsize
  183. s = LimitedSet(maxlen=10, minlen=10, expires=1)
  184. [s.add(i) for i in range(20)]
  185. s.minlen = 3
  186. s.purge(now=time() + 3)
  187. self.assertEqual(s.minlen, len(s))
  188. self.assertLessEqual(
  189. len(s._heap),
  190. s.maxlen * (100. + s.max_heap_percent_overload) / 100,
  191. )
  192. def test_pickleable(self):
  193. s = LimitedSet(maxlen=2)
  194. s.add('foo')
  195. s.add('bar')
  196. self.assertEqual(pickle.loads(pickle.dumps(s)), s)
  197. def test_iter(self):
  198. if sys.platform == 'win32':
  199. raise SkipTest('Not working on Windows')
  200. s = LimitedSet(maxlen=3)
  201. items = ['foo', 'bar', 'baz', 'xaz']
  202. for item in items:
  203. s.add(item)
  204. l = list(iter(s))
  205. for item in items[1:]:
  206. self.assertIn(item, l)
  207. self.assertNotIn('foo', l)
  208. self.assertListEqual(l, items[1:], 'order by insertion time')
  209. def test_repr(self):
  210. s = LimitedSet(maxlen=2)
  211. items = 'foo', 'bar'
  212. for item in items:
  213. s.add(item)
  214. self.assertIn('LimitedSet(', repr(s))
  215. def test_discard(self):
  216. s = LimitedSet(maxlen=2)
  217. s.add('foo')
  218. s.discard('foo')
  219. self.assertNotIn('foo', s)
  220. self.assertEqual(len(s._data), 0)
  221. s.discard('foo')
  222. def test_clear(self):
  223. s = LimitedSet(maxlen=2)
  224. s.add('foo')
  225. s.add('bar')
  226. self.assertEqual(len(s), 2)
  227. s.clear()
  228. self.assertFalse(s)
  229. def test_update(self):
  230. s1 = LimitedSet(maxlen=2)
  231. s1.add('foo')
  232. s1.add('bar')
  233. s2 = LimitedSet(maxlen=2)
  234. s2.update(s1)
  235. self.assertItemsEqual(list(s2), ['foo', 'bar'])
  236. s2.update(['bla'])
  237. self.assertItemsEqual(list(s2), ['bla', 'bar'])
  238. s2.update(['do', 're'])
  239. self.assertItemsEqual(list(s2), ['do', 're'])
  240. s1 = LimitedSet(maxlen=10, expires=None)
  241. s2 = LimitedSet(maxlen=10, expires=None)
  242. s3 = LimitedSet(maxlen=10, expires=None)
  243. s4 = LimitedSet(maxlen=10, expires=None)
  244. s5 = LimitedSet(maxlen=10, expires=None)
  245. for i in range(12):
  246. s1.add(i)
  247. s2.add(i*i)
  248. s3.update(s1)
  249. s3.update(s2)
  250. s4.update(s1.as_dict())
  251. s4.update(s2.as_dict())
  252. s5.update(s1._data) # revoke is using this
  253. s5.update(s2._data)
  254. self.assertEqual(s3, s4)
  255. self.assertEqual(s3, s5)
  256. s2.update(s4)
  257. s4.update(s2)
  258. self.assertEqual(s2, s4)
  259. def test_iterable_and_ordering(self):
  260. s = LimitedSet(maxlen=35, expires=None)
  261. # we use a custom clock here, as time.time() does not have enough
  262. # precision when called quickly (can return the same value twice).
  263. clock = count(1)
  264. for i in reversed(range(15)):
  265. s.add(i, now=next(clock))
  266. j = 40
  267. for i in s:
  268. self.assertLess(i, j) # each item is smaller and smaller
  269. j = i
  270. self.assertEqual(i, 0) # last item is zero
  271. def test_pop_and_ordering_again(self):
  272. s = LimitedSet(maxlen=5)
  273. for i in range(10):
  274. s.add(i)
  275. j = -1
  276. for _ in range(5):
  277. i = s.pop()
  278. self.assertLess(j, i)
  279. i = s.pop()
  280. self.assertEqual(i, None)
  281. def test_as_dict(self):
  282. s = LimitedSet(maxlen=2)
  283. s.add('foo')
  284. self.assertIsInstance(s.as_dict(), Mapping)
  285. def test_add_removes_duplicate_from_small_heap(self):
  286. s = LimitedSet(maxlen=2)
  287. s.add('foo')
  288. s.add('foo')
  289. s.add('foo')
  290. self.assertEqual(len(s), 1)
  291. self.assertEqual(len(s._data), 1)
  292. self.assertEqual(len(s._heap), 1)
  293. def test_add_removes_duplicate_from_big_heap(self):
  294. s = LimitedSet(maxlen=1000)
  295. [s.add(i) for i in range(2000)]
  296. self.assertEqual(len(s), 1000)
  297. [s.add('foo') for i in range(1000)]
  298. # heap is refreshed when 15% larger than _data
  299. self.assertLess(len(s._heap), 1150)
  300. [s.add('foo') for i in range(1000)]
  301. self.assertLess(len(s._heap), 1150)
  302. def assert_lengths(self, s, expected, expected_data, expected_heap):
  303. self.assertEqual(len(s), expected)
  304. self.assertEqual(len(s._data), expected_data)
  305. self.assertEqual(len(s._heap), expected_heap)
  306. class test_AttributeDict(Case):
  307. def test_getattr__setattr(self):
  308. x = AttributeDict({'foo': 'bar'})
  309. self.assertEqual(x['foo'], 'bar')
  310. with self.assertRaises(AttributeError):
  311. x.bar
  312. x.bar = 'foo'
  313. self.assertEqual(x['bar'], 'foo')
  314. class test_DependencyGraph(Case):
  315. def graph1(self):
  316. return DependencyGraph([
  317. ('A', []),
  318. ('B', []),
  319. ('C', ['A']),
  320. ('D', ['C', 'B']),
  321. ])
  322. def test_repr(self):
  323. self.assertTrue(repr(self.graph1()))
  324. def test_topsort(self):
  325. order = self.graph1().topsort()
  326. # C must start before D
  327. self.assertLess(order.index('C'), order.index('D'))
  328. # and B must start before D
  329. self.assertLess(order.index('B'), order.index('D'))
  330. # and A must start before C
  331. self.assertLess(order.index('A'), order.index('C'))
  332. def test_edges(self):
  333. self.assertItemsEqual(
  334. list(self.graph1().edges()),
  335. ['C', 'D'],
  336. )
  337. def test_connect(self):
  338. x, y = self.graph1(), self.graph1()
  339. x.connect(y)
  340. def test_valency_of_when_missing(self):
  341. x = self.graph1()
  342. self.assertEqual(x.valency_of('foobarbaz'), 0)
  343. def test_format(self):
  344. x = self.graph1()
  345. x.formatter = Mock()
  346. obj = Mock()
  347. self.assertTrue(x.format(obj))
  348. x.formatter.assert_called_with(obj)
  349. x.formatter = None
  350. self.assertIs(x.format(obj), obj)
  351. def test_items(self):
  352. self.assertDictEqual(
  353. dict(items(self.graph1())),
  354. {'A': [], 'B': [], 'C': ['A'], 'D': ['C', 'B']},
  355. )
  356. def test_repr_node(self):
  357. x = self.graph1()
  358. self.assertTrue(x.repr_node('fasdswewqewq'))
  359. def test_to_dot(self):
  360. s = WhateverIO()
  361. self.graph1().to_dot(s)
  362. self.assertTrue(s.getvalue())