瀏覽代碼

Shiny new unittests for the rate limiting feature.

Ask Solem 15 年之前
父節點
當前提交
1cf16ed918
共有 3 個文件被更改,包括 309 次插入 和 22 次删除
  1. 112 21
      celery/buckets.py
  2. 195 0
      celery/tests/test_buckets.py
  3. 2 1
      celery/tests/test_worker.py

+ 112 - 21
celery/buckets.py

@@ -5,21 +5,29 @@ RATE_MODIFIER_MAP = {"s": lambda n: n,
                      "m": lambda n: n / 60.0,
                      "h": lambda n: n / 60.0 / 60.0}
 
+# Maps the two-character prefix of a rate limit string to the integer base
+# handed to int() by parse_ratelimit_string; strings without one of these
+# prefixes are parsed as base 10.
+BASE_IDENTIFIERS = {"0x": 16,
+                    "0o": 8,
+                    "0b": 2}
 
-class BucketRateExceeded(Exception):
+
+class RateLimitExceeded(Exception):
     """The token buckets rate limit has been exceeded."""
 
 
 def parse_ratelimit_string(rate_limit):
     """Parse rate limit configurations such as ``"100/m"`` or ``"2/h"``
         and convert them into seconds."""
+
     if rate_limit:
-        try:
-            return int(rate_limit)
-        except ValueError:
-            ops, _, modifier = rate_limit.partition("/")
-            return RATE_MODIFIER_MAP[modifier](int(ops))
-    return None
+        if isinstance(rate_limit, basestring):
+            base = BASE_IDENTIFIERS.get(rate_limit[:2], 10)
+            try:
+                return int(rate_limit, base)
+            except ValueError:
+                ops, _, modifier = rate_limit.partition("/")
+                return RATE_MODIFIER_MAP[modifier](int(ops, base)) or 0
+        return rate_limit or 0
+    return 0
 
 
 class TaskBucket(object):
@@ -50,16 +58,32 @@ class TaskBucket(object):
 
 
     """
-    min_wait = 0.05
+    min_wait = 0.0
 
     def __init__(self, task_registry):
         self.task_registry = task_registry
         self.buckets = {}
         self.init_with_registry()
 
-    def put(self, task):
+    def put(self, job):
         """Put a task into the appropiate bucket."""
-        self.buckets[task_name].put_nowait(task)
+        self.buckets[job.task_name].put_nowait(job)
+    put_nowait = put
+
+    def _get(self):
+        """Make a single pass over the buckets looking for a ready item.
+
+        Returns a ``(wait, item)`` tuple: ``(0, item)`` as soon as a bucket
+        with no expected wait yields an item, otherwise
+        ``(shortest_wait, None)`` for the smallest non-zero expected time
+        reported by any bucket.
+
+        :raises QueueEmpty: If no bucket reported a wait time and none of
+            them had an item available.
+
+        """
+        waits = []
+        for queue in self.buckets.values():
+            delay = queue.expected_time()
+            if delay:
+                # Bucket is not ready yet; remember how long it needs.
+                waits.append(delay)
+                continue
+            try:
+                item = queue.get_nowait()
+            except QueueEmpty:
+                continue
+            return 0, item
+        if waits:
+            return min(waits), None
+        raise QueueEmpty
 
     def get(self, timeout=None):
         """Retrive the task from the first available bucket.
@@ -68,18 +92,33 @@ class TaskBucket(object):
         consume tokens from it.
 
         """
+        time_start = time.time()
+        did_timeout = lambda: timeout and time.time() - time_start > timeout
+
+        while True:
+            remainding_time, item = self._get()
+            if remainding_time:
+                if did_timeout():
+                    raise QueueEmpty
+                time.sleep(remainding_time)
+            else:
+                return item
+    get_nowait = get
+
+    def __old_get(self, block=True, timeout=None):
         time_spent = 0
         for bucket in self.buckets.values():
             remaining_times = []
             try:
                 return bucket.get_nowait()
-            except BucketRateExceeded:
+            except RateLimitExceeded:
                 remaining_times.append(bucket.expected_time())
             except QueueEmpty:
                 pass
 
-            if timeout and time_spent >= timeout:
-                raise QueueEmpty()
+            if not remaining_times:
+                if not block or (timeout and time_spent >= timeout):
+                    raise QueueEmpty
             else:
                 shortest_wait = min(remaining_times or [self.min_wait])
                 time_spent += shortest_wait
@@ -109,14 +148,22 @@ class TaskBucket(object):
         rate_limit = parse_ratelimit_string(rate_limit)
         if rate_limit:
             task_queue = TokenBucketQueue(rate_limit, queue=task_queue)
+        else:
+            task_queue.expected_time = lambda: 0
 
         self.buckets[task_name] = task_queue
         return task_queue
 
+    def qsize(self):
+        """Return the combined number of items held by all the buckets."""
+        total = 0
+        for queue in self.buckets.values():
+            total += queue.qsize()
+        return total
+
 
 class TokenBucketQueue(object):
-    """An implementation of the token bucket algorithm.
+    """Queue with rate limited get operations.
 
+    This uses the token bucket algorithm to rate limit the queue on get
+    operations.
     See http://en.wikipedia.org/wiki/Token_Bucket
     Most of this code was stolen from an entry in the ASPN Python Cookbook:
     http://code.activestate.com/recipes/511490/
@@ -137,6 +184,8 @@ class TokenBucketQueue(object):
         Timestamp of the last time a token was taken out of the bucket.
 
     """
+    RateLimitExceeded = RateLimitExceeded
+
     def __init__(self, fill_rate, queue=None, capacity=1):
         self.capacity = float(capacity)
         self._tokens = self.capacity
@@ -146,27 +195,69 @@ class TokenBucketQueue(object):
         self.fill_rate = float(fill_rate)
         self.timestamp = time.time()
 
-    def put(self, item, nb=False):
-        put = self.queue.put_nowait if nb else self.queue.put
+    def put(self, item, block=True):
+        """Put an item into the queue.
+
+        Also see :meth:`Queue.Queue.put`.
+
+        """
+        put = self.queue.put if block else self.queue.put_nowait
         put(item)
 
-    def get(self, nb=False):
-        get = self.queue.get_nowait if nb else self.queue.get
+    def get(self, block=True):
+        """Remove and return an item from the queue.
+
+        :raises RateLimitExceeded: If a token could not be consumed from the
+            token bucket (consuming from the queue too fast).
+        :raises Queue.Empty: If an item is not immediately available.
+
+        Also see :meth:`Queue.Queue.get`.
+
+        """
+        get = self.queue.get if block else self.queue.get_nowait
 
         if not self.can_consume(1):
-            raise BucketRateExceeded()
+            raise RateLimitExceeded
 
         return get()
 
     def get_nowait(self):
-        return self.get(nb=True)
+        """Remove and return an item from the queue without blocking.
+
+        :raises RateLimitExceeded: If a token could not be consumed from the
+            token bucket (consuming from the queue too fast).
+        :raises Queue.Empty: If an item is not immediately available.
+
+        Also see :meth:`Queue.Queue.get_nowait`."""
+        return self.get(block=False)
 
     def put_nowait(self, item):
-        return self.put(item, nb=True)
+        """Put an item into the queue without blocking.
+
+        :raises Queue.Full: If a free slot is not immediately available.
+
+        Also see :meth:`Queue.Queue.put_nowait`
+
+        """
+        return self.put(item, block=False)
 
     def qsize(self):
+        """Returns the size of the queue.
+
+        See :meth:`Queue.Queue.qsize`.
+
+        """
         return self.queue.qsize()
 
+    def wait(self, block=False):
+        """Sleep until the bucket reports no remaining wait, then get an item.
+
+        Repeatedly sleeps for whatever :meth:`expected_time` returns until it
+        returns a false value, and then returns the result of
+        ``self.get(block=block)``.
+
+        """
+        delay = self.expected_time()
+        while delay:
+            time.sleep(delay)
+            delay = self.expected_time()
+        return self.get(block=block)
+
     def can_consume(self, tokens=1):
         """Consume tokens from the bucket. Returns True if there were
         sufficient tokens otherwise False."""

+ 195 - 0
celery/tests/test_buckets.py

@@ -0,0 +1,195 @@
+import sys
+import os
+sys.path.insert(0, os.getcwd())
+import unittest
+import time
+from celery import buckets
+from celery.task.base import Task
+from celery.registry import TaskRegistry
+from celery.utils import gen_unique_id
+from itertools import chain, izip
+
+
+class MockJob(object):
+    """Minimal stand-in for a worker job, used to exercise the buckets.
+
+    Two jobs compare equal when their ``task_id``, ``task_name``, ``args``
+    and ``kwargs`` are all equal.
+
+    """
+
+    def __init__(self, task_id, task_name, args, kwargs):
+        self.task_id = task_id
+        self.task_name = task_name
+        self.args = args
+        self.kwargs = kwargs
+
+    def __eq__(self, other):
+        # Only instances of the same class are comparable.  Returning
+        # NotImplemented (rather than evaluating ``self == other`` again,
+        # which recursed forever) lets Python fall back to its default
+        # comparison for foreign types.
+        if not isinstance(other, self.__class__):
+            return NotImplemented
+        return bool(self.task_id == other.task_id \
+                and self.task_name == other.task_name \
+                and self.args == other.args \
+                and self.kwargs == other.kwargs)
+
+    def __ne__(self, other):
+        # Python 2 does not derive ``!=`` from ``__eq__``.
+        result = self.__eq__(other)
+        if result is NotImplemented:
+            return result
+        return not result
+
+    def __repr__(self):
+        return "<MockJob: task:%s id:%s args:%s kwargs:%s>" % (
+                self.task_name, self.task_id, self.args, self.kwargs)
+
+
+class TestTokenBucketQueue(unittest.TestCase):
+    """Tests for :class:`celery.buckets.TokenBucketQueue`."""
+
+    def test_empty_queue_yields_QueueEmpty(self):
+        # Needs the ``test_`` prefix to be collected by unittest, and must
+        # use the non-blocking get: a blocking get on an empty queue would
+        # wait forever instead of raising.
+        x = buckets.TokenBucketQueue(fill_rate=10)
+        self.assertRaises(buckets.QueueEmpty, x.get_nowait)
+
+    def test_bucket__put_get(self):
+        x = buckets.TokenBucketQueue(fill_rate=10)
+        x.put("The quick brown fox")
+        self.assertEqual(x.get(), "The quick brown fox")
+
+        x.put_nowait("The lazy dog")
+        # Give the bucket time to refill before taking the next item.
+        time.sleep(0.2)
+        self.assertEqual(x.get_nowait(), "The lazy dog")
+
+    def test_fill_rate(self):
+        x = buckets.TokenBucketQueue(fill_rate=10)
+        # With the default capacity of a single token and a fill rate of
+        # 10 tokens/second, draining 20 items should take more than
+        # 1.5 seconds.
+        time_start = time.time()
+        for i in xrange(20):
+            x.put(str(i))
+        for i in xrange(20):
+            sys.stderr.write("x")
+            x.wait()
+        self.assertTrue(time.time() - time_start > 1.5)
+
+    def test_can_consume(self):
+        x = buckets.TokenBucketQueue(fill_rate=1)
+        x.put("The quick brown fox")
+        self.assertEqual(x.get(), "The quick brown fox")
+        time.sleep(0.1)
+        # Not yet ready for another token
+        x.put("The lazy dog")
+        self.assertRaises(x.RateLimitExceeded, x.get)
+
+    def test_expected_time(self):
+        x = buckets.TokenBucketQueue(fill_rate=1)
+        x.put_nowait("The quick brown fox")
+        self.assertEqual(x.get_nowait(), "The quick brown fox")
+        # The only token was just consumed, so a wait must be reported.
+        self.assertTrue(x.expected_time())
+
+    def test_qsize(self):
+        x = buckets.TokenBucketQueue(fill_rate=1)
+        x.put("The quick brown fox")
+        self.assertEqual(x.qsize(), 1)
+        # assertEqual, not assertTrue: the second argument of assertTrue is
+        # only the failure message, so the value was never compared.
+        self.assertEqual(x.get_nowait(), "The quick brown fox")
+
+
+class TestRateLimitString(unittest.TestCase):
+    """Tests for :func:`celery.buckets.parse_ratelimit_string`."""
+
+    def test_conversion(self):
+        # (input, expected result) pairs, checked in order.
+        expectations = [
+            (999, 999),
+            ("1456/s", 1456),
+            ("100/m", 100 / 60.0),
+            ("10/h", 10 / 60.0 / 60.0),
+            ("0xffec/s", 0xffec),
+            ("0xcda/m", 0xcda / 60.0),
+            ("0xF/h", 0xf / 60.0 / 60.0),
+        ]
+        for given, expected in expectations:
+            self.assertEquals(buckets.parse_ratelimit_string(given),
+                              expected)
+
+        # Every spelling of "no rate limit" must come back as zero.
+        zeroes = ("0x0", "0b0", "0o0", 0, None, "0/m", "0/h", "0/s")
+        for zero in zeroes:
+            self.assertEquals(buckets.parse_ratelimit_string(zero), 0)
+
+
+class TaskA(Task):
+    # Fixture: rate limit given as a plain integer.
+    rate_limit = 10
+
+
+class TaskB(Task):
+    # Fixture: no rate limit.
+    rate_limit = None
+
+
+class TaskC(Task):
+    # Fixture: rate limit given as a per-second string.
+    rate_limit = "1/s"
+
+
+class TaskD(Task):
+    # Fixture: rate limit given as a per-minute string.
+    rate_limit = "1000/m"
+
+
+class TestTaskBuckets(unittest.TestCase):
+    """Tests for :class:`celery.buckets.TaskBucket`."""
+
+    def setUp(self):
+        # Fresh registry per test with TaskA, TaskB and TaskC registered;
+        # TaskD is registered on demand by the tests that need it.
+        self.registry = TaskRegistry()
+        self.task_classes = (TaskA, TaskB, TaskC)
+        for task_cls in self.task_classes:
+            self.registry.register(task_cls)
+
+    def test_auto_add_on_missing(self):
+        b = buckets.TaskBucket(task_registry=self.registry)
+        for task_cls in self.task_classes:
+            self.assertTrue(task_cls.name in b.buckets)
+        self.registry.register(TaskD)
+        # Unregister in ``finally`` like the other tests, so a failing
+        # assertion does not leave TaskD registered.
+        try:
+            self.assertTrue(b.get_bucket_for_type(TaskD.name))
+            self.assertTrue(TaskD.name in b.buckets)
+        finally:
+            self.registry.unregister(TaskD)
+
+    def test_has_rate_limits(self):
+        b = buckets.TaskBucket(task_registry=self.registry)
+        self.assertEqual(b.buckets[TaskA.name].fill_rate, 10)
+        self.assertTrue(isinstance(b.buckets[TaskB.name], buckets.Queue))
+        self.assertEqual(b.buckets[TaskC.name].fill_rate, 1)
+        self.registry.register(TaskD)
+        b.init_with_registry()
+        try:
+            self.assertEqual(b.buckets[TaskD.name].fill_rate, 1000 / 60.0)
+        finally:
+            self.registry.unregister(TaskD)
+
+    def test_on_empty_buckets__get_raises_empty(self):
+        b = buckets.TaskBucket(task_registry=self.registry)
+        self.assertRaises(buckets.QueueEmpty, b.get)
+        self.assertEqual(b.qsize(), 0)
+
+    def test_put__get(self):
+        b = buckets.TaskBucket(task_registry=self.registry)
+        job = MockJob(gen_unique_id(), TaskA.name, ["theqbf"], {"foo": "bar"})
+        b.put(job)
+        self.assertEqual(b.get(), job)
+
+    def test_fill_rate(self):
+        b = buckets.TaskBucket(task_registry=self.registry)
+
+        cjob = lambda i: MockJob(gen_unique_id(), TaskA.name, [i], {})
+        jobs = [cjob(i) for i in xrange(20)]
+        for job in jobs:
+            b.put(job)
+
+        self.assertEqual(b.qsize(), 20)
+
+        # TaskA is limited to 10 per second, so draining 20 items should
+        # take more than 1.5 seconds.
+        time_start = time.time()
+        for job in jobs:
+            sys.stderr.write("i")
+            self.assertEqual(b.get(), job)
+        self.assertTrue(time.time() - time_start > 1.5)
+
+    def test_thorough__multiple_types(self):
+        self.registry.register(TaskD)
+        try:
+            b = buckets.TaskBucket(task_registry=self.registry)
+
+            cjob = lambda i, t: MockJob(gen_unique_id(), t.name, [i], {})
+            ajobs = [cjob(i, TaskA) for i in xrange(10)]
+            bjobs = [cjob(i, TaskB) for i in xrange(10)]
+            cjobs = [cjob(i, TaskC) for i in xrange(10)]
+            djobs = [cjob(i, TaskD) for i in xrange(10)]
+            # Spread the jobs around.
+            jobs = list(chain(*izip(ajobs, bjobs, cjobs, djobs)))
+
+            for job in jobs:
+                b.put(job)
+            # Jobs of different types are not guaranteed to come back in
+            # the order they were put, so only check that every item
+            # returned is one of the submitted jobs.  (The second argument
+            # of assertTrue is just the failure message, so passing the job
+            # there compared nothing.)
+            retrieved = 0
+            for _ in jobs:
+                sys.stderr.write("0")
+                self.assertTrue(b.get() in jobs)
+                retrieved += 1
+            self.assertEqual(retrieved, len(jobs))
+            self.assertEqual(b.qsize(), 0)
+        finally:
+            self.registry.unregister(TaskD)
+
+# Run the tests in this module when it is executed directly as a script.
+if __name__ == "__main__":
+    unittest.main()

+ 2 - 1
celery/tests/test_worker.py

@@ -155,7 +155,8 @@ class TestWorkController(unittest.TestCase):
 
     def test_attrs(self):
         worker = self.worker
-        self.assertTrue(isinstance(worker.bucket_queue, Queue))
+        self.assertTrue(hasattr(worker.bucket_queue, "get"))
+        self.assertTrue(hasattr(worker.bucket_queue, "put"))
         self.assertTrue(isinstance(worker.hold_queue, Queue))
         self.assertTrue(worker.periodic_work_controller)
         self.assertTrue(worker.pool)