Browse Source

celery.datastructures.TokenBucket: Generic Token Bucket algorithm

Ask Solem 14 years ago
parent
commit
267ede3385
3 changed files with 63 additions and 50 deletions
  1. 52 0
      celery/datastructures.py
  2. 4 7
      celery/tests/test_buckets.py
  3. 7 43
      celery/worker/buckets.py

+ 52 - 0
celery/datastructures.py

@@ -265,3 +265,55 @@ class LocalCache(OrderedDict):
         while len(self) >= self.limit:
             self.popitem(last=False)
         super(LocalCache, self).__setitem__(key, value)
+
+
+class TokenBucket(object):
+    """Token Bucket Algorithm.
+
+    See http://en.wikipedia.org/wiki/Token_Bucket
+    Most of this code was stolen from an entry in the ASPN Python Cookbook:
+    http://code.activestate.com/recipes/511490/
+
+    :param fill_rate: see :attr:`fill_rate`.
+    :keyword capacity: see :attr:`capacity`.
+
+    .. attribute:: fill_rate
+
+        The rate in tokens/second that the bucket will be refilled.
+
+    .. attribute:: capacity
+
+        Maximum number of tokens in the bucket. Default is ``1``.
+
+    .. attribute:: timestamp
+
+        Timestamp of the last time a token was taken out of the bucket.
+
+    """
+
+    def __init__(self, fill_rate, capacity=1):
+        self.capacity = float(capacity)
+        self._tokens = capacity
+        self.fill_rate = float(fill_rate)
+        self.timestamp = time.time()
+
+    def can_consume(self, tokens=1):
+        if tokens <= self._get_tokens():
+            self._tokens -= tokens
+            return True
+        return False
+
+    def expected_time(self, tokens=1):
+        """Returns the expected time in seconds when a new token should be
+        available. *Note: consumes a token from the bucket*"""
+        _tokens = self._get_tokens()
+        tokens = max(tokens, _tokens)
+        return (tokens - _tokens) / self.fill_rate
+
+    def _get_tokens(self):
+        if self._tokens < self.capacity:
+            now = time.time()
+            delta = self.fill_rate * (now - self.timestamp)
+            self._tokens = min(self.capacity, self._tokens + delta)
+            self.timestamp = now
+        return self._tokens

+ 4 - 7
celery/tests/test_buckets.py

@@ -183,13 +183,14 @@ class test_TaskBucket(unittest.TestCase):
     @skip_if_disabled
     def test_has_rate_limits(self):
         b = buckets.TaskBucket(task_registry=self.registry)
-        self.assertEqual(b.buckets[TaskA.name].fill_rate, 10)
+        self.assertEqual(b.buckets[TaskA.name]._bucket.fill_rate, 10)
         self.assertIsInstance(b.buckets[TaskB.name], buckets.Queue)
-        self.assertEqual(b.buckets[TaskC.name].fill_rate, 1)
+        self.assertEqual(b.buckets[TaskC.name]._bucket.fill_rate, 1)
         self.registry.register(TaskD)
         b.init_with_registry()
         try:
-            self.assertEqual(b.buckets[TaskD.name].fill_rate, 1000 / 60.0)
+            self.assertEqual(b.buckets[TaskD.name]._bucket.fill_rate,
+                             1000 / 60.0)
         finally:
             self.registry.unregister(TaskD)
 
@@ -284,10 +285,6 @@ class test_TaskBucket(unittest.TestCase):
 
 class test_FastQueue(unittest.TestCase):
 
-    def test_can_consume(self):
-        x = buckets.FastQueue()
-        self.assertTrue(x.can_consume())
-
     def test_items(self):
         x = buckets.FastQueue()
         x.put(10)

+ 7 - 43
celery/worker/buckets.py

@@ -3,6 +3,7 @@ import time
 from collections import deque
 from Queue import Queue, Empty as QueueEmpty
 
+from celery.datastructures import TokenBucket
 from celery.utils import all
 from celery.utils import timeutils
 from celery.utils.compat import izip_longest, chain_from_iterable
@@ -194,9 +195,6 @@ class FastQueue(Queue):
     def expected_time(self, tokens=1):
         return 0
 
-    def can_consume(self, tokens=1):
-        return True
-
     def wait(self, block=True):
         return self.get(block=block)
 
@@ -210,36 +208,19 @@ class TokenBucketQueue(object):
 
     This uses the token bucket algorithm to rate limit the queue on get
     operations.
-    See http://en.wikipedia.org/wiki/Token_Bucket
-    Most of this code was stolen from an entry in the ASPN Python Cookbook:
-    http://code.activestate.com/recipes/511490/
-
-    :param fill_rate: see :attr:`fill_rate`.
-    :keyword capacity: see :attr:`capacity`.
-
-    .. attribute:: fill_rate
-
-        The rate in tokens/second that the bucket will be refilled.
 
-    .. attribute:: capacity
-
-        Maximum number of tokens in the bucket. Default is ``1``.
-
-    .. attribute:: timestamp
-
-        Timestamp of the last time a token was taken out of the bucket.
+    :param fill_rate: The rate in tokens/second that the bucket will
+      be refilled.
+    :keyword capacity: Maximum number of tokens in the bucket. Default is 1.
 
     """
     RateLimitExceeded = RateLimitExceeded
 
     def __init__(self, fill_rate, queue=None, capacity=1):
-        self.capacity = float(capacity)
-        self._tokens = self.capacity
+        self._bucket = TokenBucket(fill_rate, capacity)
         self.queue = queue
         if not self.queue:
             self.queue = Queue()
-        self.fill_rate = float(fill_rate)
-        self.timestamp = time.time()
 
     def put(self, item, block=True):
         """Put an item into the queue.
@@ -271,7 +252,7 @@ class TokenBucketQueue(object):
         """
         get = block and self.queue.get or self.queue.get_nowait
 
-        if not self.can_consume(1):
+        if not self._bucket.can_consume(1):
             raise RateLimitExceeded()
 
         return get()
@@ -311,27 +292,10 @@ class TokenBucketQueue(object):
                 return self.get(block=block)
             time.sleep(remaining)
 
-    def can_consume(self, tokens=1):
-        """Consume tokens from the bucket. Returns True if there were
-        sufficient tokens otherwise False."""
-        if tokens <= self._get_tokens():
-            self._tokens -= tokens
-            return True
-        return False
-
     def expected_time(self, tokens=1):
         """Returns the expected time in seconds when a new token should be
         available."""
-        tokens = max(tokens, self._get_tokens())
-        return (tokens - self._get_tokens()) / self.fill_rate
-
-    def _get_tokens(self):
-        if self._tokens < self.capacity:
-            now = time.time()
-            delta = self.fill_rate * (now - self.timestamp)
-            self._tokens = min(self.capacity, self._tokens + delta)
-            self.timestamp = now
-        return self._tokens
+        return self._bucket.expected_time(tokens)
 
     @property
     def items(self):