test_buckets.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. import sys
  2. import os
  3. sys.path.insert(0, os.getcwd())
  4. import unittest
  5. import time
  6. from celery import buckets
  7. from celery.task.base import Task
  8. from celery.registry import TaskRegistry
  9. from celery.utils import gen_unique_id
  10. from itertools import chain, izip
  11. class MockJob(object):
  12. def __init__(self, task_id, task_name, args, kwargs):
  13. self.task_id = task_id
  14. self.task_name = task_name
  15. self.args = args
  16. self.kwargs = kwargs
  17. def __eq__(self, other):
  18. if isinstance(other, self.__class__):
  19. return bool(self.task_id == other.task_id \
  20. and self.task_name == other.task_name \
  21. and self.args == other.args \
  22. and self.kwargs == other.kwargs)
  23. else:
  24. return self == other
  25. def __repr__(self):
  26. return "<MockJob: task:%s id:%s args:%s kwargs:%s" % (
  27. self.task_name, self.task_id, self.args, self.kwargs)
  28. class TestTokenBucketQueue(unittest.TestCase):
  29. def empty_queue_yields_QueueEmpty(self):
  30. x = buckets.TokenBucketQueue(fill_rate=10)
  31. self.assertRaises(buckets.QueueEmpty, x.get)
  32. def test_bucket__put_get(self):
  33. x = buckets.TokenBucketQueue(fill_rate=10)
  34. x.put("The quick brown fox")
  35. self.assertEquals(x.get(), "The quick brown fox")
  36. x.put_nowait("The lazy dog")
  37. time.sleep(0.2)
  38. self.assertEquals(x.get_nowait(), "The lazy dog")
  39. def test_fill_rate(self):
  40. x = buckets.TokenBucketQueue(fill_rate=10)
  41. # 20 items should take at least one second to complete
  42. time_start = time.time()
  43. [x.put(str(i)) for i in xrange(20)]
  44. for i in xrange(20):
  45. sys.stderr.write("x")
  46. x.wait()
  47. self.assertTrue(time.time() - time_start > 1.5)
  48. def test_can_consume(self):
  49. x = buckets.TokenBucketQueue(fill_rate=1)
  50. x.put("The quick brown fox")
  51. self.assertEqual(x.get(), "The quick brown fox")
  52. time.sleep(0.1)
  53. # Not yet ready for another token
  54. x.put("The lazy dog")
  55. self.assertRaises(x.RateLimitExceeded, x.get)
  56. def test_expected_time(self):
  57. x = buckets.TokenBucketQueue(fill_rate=1)
  58. x.put_nowait("The quick brown fox")
  59. self.assertEqual(x.get_nowait(), "The quick brown fox")
  60. self.assertTrue(x.expected_time())
  61. def test_qsize(self):
  62. x = buckets.TokenBucketQueue(fill_rate=1)
  63. x.put("The quick brown fox")
  64. self.assertEqual(x.qsize(), 1)
  65. self.assertTrue(x.get_nowait(), "The quick brown fox")
  66. class TestRateLimitString(unittest.TestCase):
  67. def test_conversion(self):
  68. self.assertEquals(buckets.parse_ratelimit_string(999), 999)
  69. self.assertEquals(buckets.parse_ratelimit_string("1456/s"), 1456)
  70. self.assertEquals(buckets.parse_ratelimit_string("100/m"),
  71. 100 / 60.0)
  72. self.assertEquals(buckets.parse_ratelimit_string("10/h"),
  73. 10 / 60.0 / 60.0)
  74. self.assertEquals(buckets.parse_ratelimit_string("0xffec/s"), 0xffec)
  75. self.assertEquals(buckets.parse_ratelimit_string("0xcda/m"),
  76. 0xcda / 60.0)
  77. self.assertEquals(buckets.parse_ratelimit_string("0xF/h"),
  78. 0xf / 60.0 / 60.0)
  79. for zero in ("0x0", "0b0", "0o0", 0, None, "0/m", "0/h", "0/s"):
  80. self.assertEquals(buckets.parse_ratelimit_string(zero), 0)
  81. class TaskA(Task):
  82. rate_limit = 10
  83. class TaskB(Task):
  84. rate_limit = None
  85. class TaskC(Task):
  86. rate_limit = "1/s"
  87. class TaskD(Task):
  88. rate_limit = "1000/m"
  89. class TestTaskBuckets(unittest.TestCase):
  90. def setUp(self):
  91. self.registry = TaskRegistry()
  92. self.task_classes = (TaskA, TaskB, TaskC)
  93. for task_cls in self.task_classes:
  94. self.registry.register(task_cls)
  95. def test_auto_add_on_missing(self):
  96. b = buckets.TaskBucket(task_registry=self.registry)
  97. for task_cls in self.task_classes:
  98. self.assertTrue(task_cls.name in b.buckets.keys())
  99. self.registry.register(TaskD)
  100. self.assertTrue(b.get_bucket_for_type(TaskD.name))
  101. self.assertTrue(TaskD.name in b.buckets.keys())
  102. self.registry.unregister(TaskD)
  103. def test_has_rate_limits(self):
  104. b = buckets.TaskBucket(task_registry=self.registry)
  105. self.assertEqual(b.buckets[TaskA.name].fill_rate, 10)
  106. self.assertTrue(isinstance(b.buckets[TaskB.name], buckets.Queue))
  107. self.assertEqual(b.buckets[TaskC.name].fill_rate, 1)
  108. self.registry.register(TaskD)
  109. b.init_with_registry()
  110. try:
  111. self.assertEqual(b.buckets[TaskD.name].fill_rate, 1000 / 60.0)
  112. finally:
  113. self.registry.unregister(TaskD)
  114. def test_on_empty_buckets__get_raises_empty(self):
  115. b = buckets.TaskBucket(task_registry=self.registry)
  116. self.assertRaises(buckets.QueueEmpty, b.get)
  117. self.assertEqual(b.qsize(), 0)
  118. def test_put__get(self):
  119. b = buckets.TaskBucket(task_registry=self.registry)
  120. job = MockJob(gen_unique_id(), TaskA.name, ["theqbf"], {"foo": "bar"})
  121. b.put(job)
  122. self.assertEquals(b.get(), job)
  123. def test_fill_rate(self):
  124. b = buckets.TaskBucket(task_registry=self.registry)
  125. cjob = lambda i: MockJob(gen_unique_id(), TaskA.name, [i], {})
  126. jobs = [cjob(i) for i in xrange(20)]
  127. [b.put(job) for job in jobs]
  128. self.assertEqual(b.qsize(), 20)
  129. # 20 items should take at least one second to complete
  130. time_start = time.time()
  131. for i, job in enumerate(jobs):
  132. sys.stderr.write("i")
  133. self.assertEqual(b.get(), job)
  134. self.assertTrue(time.time() - time_start > 1.5)
  135. def test__very_busy_queue_doesnt_block_others(self):
  136. b = buckets.TaskBucket(task_registry=self.registry)
  137. cjob = lambda i, t: MockJob(gen_unique_id(), t.name, [i], {})
  138. ajobs = [cjob(i, TaskA) for i in xrange(10)]
  139. bjobs = [cjob(i, TaskB) for i in xrange(20)]
  140. jobs = list(chain(*izip(bjobs, ajobs)))
  141. map(b.put, jobs)
  142. got_ajobs = 0
  143. for job in (b.get() for i in xrange(20)):
  144. if job.task_name == TaskA.name:
  145. got_ajobs += 1
  146. self.assertTrue(got_ajobs > 2)
  147. def test_thorough__multiple_types(self):
  148. self.registry.register(TaskD)
  149. try:
  150. b = buckets.TaskBucket(task_registry=self.registry)
  151. cjob = lambda i, t: MockJob(gen_unique_id(), t.name, [i], {})
  152. ajobs = [cjob(i, TaskA) for i in xrange(10)]
  153. bjobs = [cjob(i, TaskB) for i in xrange(10)]
  154. cjobs = [cjob(i, TaskC) for i in xrange(10)]
  155. djobs = [cjob(i, TaskD) for i in xrange(10)]
  156. # Spread the jobs around.
  157. jobs = list(chain(*izip(ajobs, bjobs, cjobs, djobs)))
  158. [b.put(job) for job in jobs]
  159. for i, job in enumerate(jobs):
  160. sys.stderr.write("0")
  161. self.assertTrue(b.get(), job)
  162. self.assertEqual(i+1, len(jobs))
  163. finally:
  164. self.registry.unregister(TaskD)
  165. if __name__ == "__main__":
  166. unittest.main()