Browse Source

Merge branch 'tokenbucket' of github.com:ask/celery into tokenbucket

Ask Solem 15 years ago
parent
commit
064fe638b1
6 changed files with 558 additions and 9 deletions
  1. 299 0
      celery/buckets.py
  2. 23 1
      celery/conf.py
  3. 9 0
      celery/task/base.py
  4. 214 0
      celery/tests/test_buckets.py
  5. 2 1
      celery/tests/test_worker.py
  6. 11 7
      celery/worker/__init__.py

+ 299 - 0
celery/buckets.py

@@ -0,0 +1,299 @@
+import time
+from Queue import Queue
+from Queue import Empty as QueueEmpty
+
# Maps a rate-limit modifier suffix ("100/m" -> "m") to a function converting
# that many operations-per-interval into operations per second.
RATE_MODIFIER_MAP = {"s": lambda n: n,
                     "m": lambda n: n / 60.0,
                     "h": lambda n: n / 60.0 / 60.0}

# Numeric base prefixes recognized in rate-limit strings (hex/octal/binary).
BASE_IDENTIFIERS = {"0x": 16,
                    "0o": 8,
                    "0b": 2}
+
+
class RateLimitExceeded(Exception):
    """The token bucket's rate limit has been exceeded.

    Raised by :meth:`TokenBucketQueue.get` when no token could be
    consumed from the bucket.

    """
+
+
def parse_ratelimit_string(rate_limit):
    """Parse rate limit configurations such as ``"100/m"`` or ``"2/h"``
    and convert them into operations per second (as a float).

    Plain numbers, and number strings without a modifier (optionally
    with a ``0x``/``0o``/``0b`` base prefix), are taken to already mean
    operations per second.

    Returns ``0`` for no rate limit.

    """
    if not rate_limit:
        # None, 0, "" and other falsy values all mean "no rate limit".
        return 0
    if isinstance(rate_limit, basestring):
        base = BASE_IDENTIFIERS.get(rate_limit[:2], 10)
        try:
            # Bare number string, e.g. "100" or "0xffec".
            return int(rate_limit, base)
        except ValueError:
            # "ops/modifier" form. NOTE(review): an unrecognized modifier
            # (anything but s/m/h) raises KeyError here -- confirm whether
            # a friendlier error is wanted before changing it.
            ops, _, modifier = rate_limit.partition("/")
            return RATE_MODIFIER_MAP[modifier](int(ops, base)) or 0
    # Already a number; truthy here, so no ``or 0`` needed.
    return rate_limit
+
+
class TaskBucket(object):
    """This is a collection of token buckets, each task type having
    its own token bucket. If the task type doesn't have a rate limit,
    it will have a plain Queue object instead of a token bucket queue.

    The :meth:`put` operation forwards the task to its appropriate bucket,
    while the :meth:`get` operation iterates over the buckets and retrieves
    the first available item.

    Say we have three types of tasks in the registry: ``celery.ping``,
    ``feed.refresh`` and ``video.compress``, the TaskBucket will consist
    of the following items::

        {"celery.ping": TokenBucketQueue(fill_rate=300),
         "feed.refresh": Queue(),
         "video.compress": TokenBucketQueue(fill_rate=2)}

    The get operation will iterate over these until one of the buckets
    is able to return an item. The underlying datastructure is a ``dict``,
    so the order is ignored here.

    :param task_registry: The task registry used to get the task
        type class for a given task name.

    """
    # NOTE(review): not referenced anywhere in this class -- presumably a
    # planned minimum sleep interval; confirm before relying on it.
    min_wait = 0.0

    def __init__(self, task_registry):
        self.task_registry = task_registry
        self.buckets = {}
        self.init_with_registry()
        # Cache of items that are ready for immediate delivery (see _get).
        # Created after init_with_registry(), which only touches
        # self.buckets.
        self.immediate = Queue()

    def put(self, job):
        """Put a task into the appropriate bucket, keyed by the job's
        ``task_name``."""
        self.buckets[job.task_name].put_nowait(job)
    put_nowait = put

    def _get(self):
        # If the first bucket is always returning items, we would never
        # get to fetch items from the other buckets. So we always iterate over
        # all the buckets and put any ready items into a queue called
        # "immediate". This queue is always checked for cached items first.
        # NOTE(review): a Queue instance is always truthy, so this check
        # never skips the lookup; emptiness is handled by get_nowait below.
        if self.immediate:
            try:
                return 0, self.immediate.get_nowait()
            except QueueEmpty:
                pass

        remaining_times = []
        for bucket in self.buckets.values():
            remaining = bucket.expected_time()
            if not remaining:
                try:
                    # Just put any ready items into the immediate queue.
                    self.immediate.put_nowait(bucket.get_nowait())
                except QueueEmpty:
                    pass
                except RateLimitExceeded:
                    # The token was taken between the expected_time()
                    # check and the get; ask the bucket again.
                    remaining_times.append(bucket.expected_time())
            else:
                remaining_times.append(remaining)

        # Try the immediate queue again.
        try:
            return 0, self.immediate.get_nowait()
        except QueueEmpty:
            if not remaining_times:
                # No items in any of the buckets.
                raise

            # There's items, but have to wait before we can retrieve them,
            # return the shortest remaining time.
            return min(remaining_times), None

    def get(self, block=True, timeout=None):
        """Retrieve the task from the first available bucket.

        Available as in, there is an item in the queue and you can
        consume tokens from it.

        :keyword block: wait (sleeping in between polls) until an item
            is available.
        :keyword timeout: maximum number of seconds to wait before
            raising :exc:`Queue.Empty` (only used when ``block`` is
            true).

        """
        time_start = time.time()
        did_timeout = lambda: timeout and time.time() - time_start > timeout

        while True:
            remaining_time, item = self._get()
            if remaining_time:
                if not block or did_timeout():
                    raise QueueEmpty
                # NOTE(review): this can overshoot the requested timeout
                # by up to ``remaining_time`` seconds.
                time.sleep(remaining_time)
            else:
                return item

    def get_nowait(self):
        """Like :meth:`get`, but raises :exc:`Queue.Empty` instead of
        waiting for a token/item to become available."""
        return self.get(block=False)

    def init_with_registry(self):
        """Initialize with buckets for all the task types in the registry.

        Task types that already have a bucket keep it, so this is safe
        to call again after new task types have been registered.
        (Previously this called :meth:`add_bucket_for_type` directly,
        which asserts on duplicates and made re-initialization blow up.)

        """
        for task_name in self.task_registry.keys():
            self.get_bucket_for_type(task_name)

    def get_bucket_for_type(self, task_name):
        """Get the bucket for a particular task type, creating it on
        first use."""
        if task_name not in self.buckets:
            return self.add_bucket_for_type(task_name)
        return self.buckets[task_name]

    def add_bucket_for_type(self, task_name):
        """Add a bucket for a task type.

        Will read the tasks rate limit and create a :class:`TokenBucketQueue`
        if it has one. If the task doesn't have a rate limit a regular Queue
        will be used.

        """
        assert task_name not in self.buckets
        task_type = self.task_registry[task_name]
        task_queue = Queue()
        rate_limit = getattr(task_type, "rate_limit", None)
        rate_limit = parse_ratelimit_string(rate_limit)
        if rate_limit:
            task_queue = TokenBucketQueue(rate_limit, queue=task_queue)
        else:
            # Plain queues don't know about rate limits, but _get() calls
            # expected_time() on every bucket, so fake an always-ready one.
            task_queue.expected_time = lambda: 0

        self.buckets[task_name] = task_queue
        return task_queue

    def qsize(self):
        """Get the total size of all the queues."""
        return sum(bucket.qsize() for bucket in self.buckets.values())

    def empty(self):
        """Returns ``True`` if all of the buckets are empty."""
        return all(bucket.empty() for bucket in self.buckets.values())
+
+
class TokenBucketQueue(object):
    """Queue with rate limited get operations.

    This uses the token bucket algorithm to rate limit the queue on get
    operations.

    See http://en.wikipedia.org/wiki/Token_Bucket
    Most of this code was stolen from an entry in the ASPN Python Cookbook:
    http://code.activestate.com/recipes/511490/

    :param fill_rate: see :attr:`fill_rate`.
    :keyword queue: optional backing queue; a fresh :class:`Queue.Queue`
        is created when not given.
    :keyword capacity: see :attr:`capacity`.

    .. attribute:: fill_rate

        The rate in tokens/second that the bucket will be refilled.

    .. attribute:: capacity

        Maximum number of tokens in the bucket. Default is ``1``.

    .. attribute:: timestamp

        Timestamp of the last time a token was taken out of the bucket.

    """
    # Re-exported so callers can catch ``queue_instance.RateLimitExceeded``.
    RateLimitExceeded = RateLimitExceeded

    def __init__(self, fill_rate, queue=None, capacity=1):
        self.capacity = float(capacity)
        # The bucket starts out full.
        self._tokens = self.capacity
        # ``is None`` rather than truthiness: a caller-supplied queue
        # object must never be silently replaced.
        self.queue = Queue() if queue is None else queue
        self.fill_rate = float(fill_rate)
        self.timestamp = time.time()

    def put(self, item, block=True):
        """Put an item into the queue.

        Note that putting is never rate limited, only getting is.

        Also see :meth:`Queue.Queue.put`.

        """
        put = self.queue.put if block else self.queue.put_nowait
        put(item)

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        :raises Queue.Full: If a free slot is not immediately available.

        Also see :meth:`Queue.Queue.put_nowait`

        """
        return self.put(item, block=False)

    def get(self, block=True):
        """Remove and return an item from the queue.

        :raises RateLimitExceeded: If a token could not be consumed from the
            token bucket (consuming from the queue too fast).
        :raises Queue.Empty: If an item is not immediately available.

        Also see :meth:`Queue.Queue.get`.

        """
        get = self.queue.get if block else self.queue.get_nowait

        # Consume the token before touching the queue so a rate-limited
        # caller never dequeues anything.
        if not self.can_consume(1):
            raise RateLimitExceeded

        return get()

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        :raises RateLimitExceeded: If a token could not be consumed from the
            token bucket (consuming from the queue too fast).
        :raises Queue.Empty: If an item is not immediately available.

        Also see :meth:`Queue.Queue.get_nowait`.

        """
        return self.get(block=False)

    def qsize(self):
        """Returns the size of the queue.

        See :meth:`Queue.Queue.qsize`.

        """
        return self.queue.qsize()

    def empty(self):
        """Returns ``True`` if the underlying queue is empty."""
        return self.queue.empty()

    def wait(self, block=False):
        """Sleep until a token can be retrieved from the bucket and
        return the next item.

        With the default ``block=False`` this still raises
        :exc:`Queue.Empty` if the queue itself holds no item once a
        token is available.

        """
        while True:
            remaining = self.expected_time()
            if not remaining:
                return self.get(block=block)
            time.sleep(remaining)

    def can_consume(self, tokens=1):
        """Consume tokens from the bucket. Returns True if there were
        sufficient tokens otherwise False."""
        if tokens <= self._get_tokens():
            self._tokens -= tokens
            return True
        return False

    def expected_time(self, tokens=1):
        """Returns the expected time in seconds until ``tokens`` tokens
        are available (``0`` if they already are)."""
        # Sample the token count exactly once. The original recipe read
        # _get_tokens() twice; the second (later) reading can be larger,
        # yielding a small *negative* delay that callers then pass to
        # time.sleep(), which raises.
        available = self._get_tokens()
        return max(tokens - available, 0.0) / self.fill_rate

    def _get_tokens(self):
        # Lazily refill the bucket based on the time elapsed since the
        # last refill, capped at ``capacity``.
        if self._tokens < self.capacity:
            now = time.time()
            delta = self.fill_rate * (now - self.timestamp)
            self._tokens = min(self.capacity, self._tokens + delta)
            self.timestamp = now
        return self._tokens

+ 23 - 1
celery/conf.py

@@ -22,6 +22,7 @@ DEFAULT_AMQP_CONNECTION_MAX_RETRIES = 100
 DEFAULT_TASK_SERIALIZER = "pickle"
 DEFAULT_BACKEND = "database"
 DEFAULT_PERIODIC_STATUS_BACKEND = "database"
+DEFAULT_DISABLE_RATE_LIMITS = False
 
 
 """
@@ -254,7 +255,6 @@ CELERY_PERIODIC_STATUS_BACKEND = getattr(settings,
                                     "CELERY_PERIODIC_STATUS_BACKEND",
                                     DEFAULT_PERIODIC_STATUS_BACKEND)
 
-
 """
 
 .. data:: CELERY_CACHE_BACKEND
@@ -264,3 +264,25 @@ cache backend in ``CACHE_BACKEND`` will be used.
 
 """
 CELERY_CACHE_BACKEND = getattr(settings, "CELERY_CACHE_BACKEND", None)
+
+
+"""
+
+.. data:: DEFAULT_RATE_LIMIT
+
+The default rate limit applied to all tasks which don't have a custom
+rate limit defined. (Default: None)
+
+"""
+DEFAULT_RATE_LIMIT = getattr(settings, "CELERY_DEFAULT_RATE_LIMIT", None)
+
+"""
+
+.. data:: DISABLE_RATE_LIMITS
+
+If ``True`` all rate limits will be disabled and all tasks will be executed
+as soon as possible.
+
+"""
+DISABLE_RATE_LIMITS = getattr(settings, "CELERY_DISABLE_RATE_LIMITS",
+                              DEFAULT_DISABLE_RATE_LIMITS)

+ 9 - 0
celery/task/base.py

@@ -114,6 +114,14 @@ class Task(object):
         Defeault time in seconds before a retry of the task should be
         executed. Default is a 1 minute delay.
 
+    .. attribute:: rate_limit
+
+        Set the rate limit for this task type; if this is ``None`` no
+        rate limit is in effect. The rate limit can be specified in
+        seconds, minutes or hours by appending ``"/s"``, ``"/m"`` or
+        ``"/h"`` to the value. If the value is an integer it is
+        interpreted as seconds. Example: ``"100/m"`` (hundred tasks a
+        minute). Default is the ``CELERY_DEFAULT_RATE_LIMIT`` setting
+        (which is off if not specified).
+
     .. attribute:: ignore_result
 
         Don't store the status and return value. This means you can't
@@ -178,6 +186,7 @@ class Task(object):
     max_retries = 3
     default_retry_delay = 3 * 60
     serializer = conf.TASK_SERIALIZER
+    rate_limit = conf.DEFAULT_RATE_LIMIT
 
     MaxRetriesExceededError = MaxRetriesExceededError
 

+ 214 - 0
celery/tests/test_buckets.py

@@ -0,0 +1,214 @@
+import sys
+import os
+sys.path.insert(0, os.getcwd())
+import unittest
+import time
+from celery import buckets
+from celery.task.base import Task
+from celery.registry import TaskRegistry
+from celery.utils import gen_unique_id
+from itertools import chain, izip
+
+
class MockJob(object):
    """Minimal stand-in for a task wrapper, carrying only the fields
    :class:`celery.buckets.TaskBucket` inspects (``task_name`` for
    routing, plus id/args/kwargs for equality checks in the tests)."""

    def __init__(self, task_id, task_name, args, kwargs):
        self.task_id = task_id
        self.task_name = task_name
        self.args = args
        self.kwargs = kwargs

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return bool(self.task_id == other.task_id \
                    and self.task_name == other.task_name \
                    and self.args == other.args \
                    and self.kwargs == other.kwargs)
        # The original fell through to ``return self == other`` here,
        # which recursed infinitely for non-MockJob operands. Defer to
        # the other operand's __eq__ instead.
        return NotImplemented

    def __repr__(self):
        # ">" was missing from the original format string.
        return "<MockJob: task:%s id:%s args:%s kwargs:%s>" % (
                self.task_name, self.task_id, self.args, self.kwargs)
+
+
class TestTokenBucketQueue(unittest.TestCase):
    """Unit tests for :class:`celery.buckets.TokenBucketQueue`."""

    def test_empty_queue_yields_QueueEmpty(self):
        # This test was previously named without the ``test_`` prefix,
        # so unittest never ran it. It also called the blocking get(),
        # which would hang forever on an empty queue; use get_nowait.
        x = buckets.TokenBucketQueue(fill_rate=10)
        self.assertRaises(buckets.QueueEmpty, x.get_nowait)

    def test_bucket__put_get(self):
        x = buckets.TokenBucketQueue(fill_rate=10)
        x.put("The quick brown fox")
        self.assertEquals(x.get(), "The quick brown fox")

        x.put_nowait("The lazy dog")
        # Give the bucket time to refill a token.
        time.sleep(0.2)
        self.assertEquals(x.get_nowait(), "The lazy dog")

    def test_fill_rate(self):
        x = buckets.TokenBucketQueue(fill_rate=10)
        # 20 items should take at least one second to complete
        time_start = time.time()
        [x.put(str(i)) for i in xrange(20)]
        for i in xrange(20):
            sys.stderr.write("x")
            x.wait()
        self.assertTrue(time.time() - time_start > 1.5)

    def test_can_consume(self):
        x = buckets.TokenBucketQueue(fill_rate=1)
        x.put("The quick brown fox")
        self.assertEqual(x.get(), "The quick brown fox")
        time.sleep(0.1)
        # Not yet ready for another token
        x.put("The lazy dog")
        self.assertRaises(x.RateLimitExceeded, x.get)

    def test_expected_time(self):
        x = buckets.TokenBucketQueue(fill_rate=1)
        x.put_nowait("The quick brown fox")
        self.assertEqual(x.get_nowait(), "The quick brown fox")
        self.assertTrue(x.expected_time())

    def test_qsize(self):
        x = buckets.TokenBucketQueue(fill_rate=1)
        x.put("The quick brown fox")
        self.assertEqual(x.qsize(), 1)
        # assertTrue(a, b) treats ``b`` as a failure message; assertEqual
        # is what was intended here.
        self.assertEqual(x.get_nowait(), "The quick brown fox")
+
+
class TestRateLimitString(unittest.TestCase):
    """Unit tests for :func:`celery.buckets.parse_ratelimit_string`."""

    def test_conversion(self):
        # (input, expected operations per second) pairs covering plain
        # ints, "/s", "/m", "/h" modifiers and hex base prefixes.
        cases = [
            (999, 999),
            ("1456/s", 1456),
            ("100/m", 100 / 60.0),
            ("10/h", 10 / 60.0 / 60.0),
            ("0xffec/s", 0xffec),
            ("0xcda/m", 0xcda / 60.0),
            ("0xF/h", 0xf / 60.0 / 60.0),
        ]
        for given, expected in cases:
            self.assertEquals(buckets.parse_ratelimit_string(given),
                              expected)

        # Every way of spelling "zero" means no rate limit at all.
        for zero in ("0x0", "0b0", "0o0", 0, None, "0/m", "0/h", "0/s"):
            self.assertEquals(buckets.parse_ratelimit_string(zero), 0)
+
+
# Fixture task types covering each way a rate limit can be declared:
# a plain int (ops/second), no limit at all, and "n/modifier" strings.


class TaskA(Task):
    rate_limit = 10


class TaskB(Task):
    rate_limit = None


class TaskC(Task):
    rate_limit = "1/s"


class TaskD(Task):
    rate_limit = "1000/m"
+
+
class TestTaskBuckets(unittest.TestCase):
    """Unit tests for :class:`celery.buckets.TaskBucket`."""

    def setUp(self):
        self.registry = TaskRegistry()
        self.task_classes = (TaskA, TaskB, TaskC)
        for task_cls in self.task_classes:
            self.registry.register(task_cls)

    def test_auto_add_on_missing(self):
        b = buckets.TaskBucket(task_registry=self.registry)
        for task_cls in self.task_classes:
            self.assertTrue(task_cls.name in b.buckets.keys())
        self.registry.register(TaskD)
        self.assertTrue(b.get_bucket_for_type(TaskD.name))
        self.assertTrue(TaskD.name in b.buckets.keys())
        self.registry.unregister(TaskD)

    def test_has_rate_limits(self):
        b = buckets.TaskBucket(task_registry=self.registry)
        self.assertEqual(b.buckets[TaskA.name].fill_rate, 10)
        self.assertTrue(isinstance(b.buckets[TaskB.name], buckets.Queue))
        self.assertEqual(b.buckets[TaskC.name].fill_rate, 1)
        self.registry.register(TaskD)
        b.init_with_registry()
        try:
            self.assertEqual(b.buckets[TaskD.name].fill_rate, 1000 / 60.0)
        finally:
            self.registry.unregister(TaskD)

    def test_on_empty_buckets__get_raises_empty(self):
        b = buckets.TaskBucket(task_registry=self.registry)
        self.assertRaises(buckets.QueueEmpty, b.get)
        self.assertEqual(b.qsize(), 0)

    def test_put__get(self):
        b = buckets.TaskBucket(task_registry=self.registry)
        job = MockJob(gen_unique_id(), TaskA.name, ["theqbf"], {"foo": "bar"})
        b.put(job)
        self.assertEquals(b.get(), job)

    def test_fill_rate(self):
        b = buckets.TaskBucket(task_registry=self.registry)

        cjob = lambda i: MockJob(gen_unique_id(), TaskA.name, [i], {})
        jobs = [cjob(i) for i in xrange(20)]
        [b.put(job) for job in jobs]

        self.assertEqual(b.qsize(), 20)

        # 20 items should take at least one second to complete
        time_start = time.time()
        for i, job in enumerate(jobs):
            sys.stderr.write("i")
            self.assertEqual(b.get(), job)
        self.assertTrue(time.time() - time_start > 1.5)

    def test__very_busy_queue_doesnt_block_others(self):
        b = buckets.TaskBucket(task_registry=self.registry)

        cjob = lambda i, t: MockJob(gen_unique_id(), t.name, [i], {})
        ajobs = [cjob(i, TaskA) for i in xrange(10)]
        bjobs = [cjob(i, TaskB) for i in xrange(20)]
        jobs = list(chain(*izip(bjobs, ajobs)))
        map(b.put, jobs)

        got_ajobs = 0
        for job in (b.get() for i in xrange(20)):
            if job.task_name == TaskA.name:
                got_ajobs += 1

        self.assertTrue(got_ajobs > 2)

    def test_thorough__multiple_types(self):
        self.registry.register(TaskD)
        try:
            b = buckets.TaskBucket(task_registry=self.registry)

            cjob = lambda i, t: MockJob(gen_unique_id(), t.name, [i], {})

            ajobs = [cjob(i, TaskA) for i in xrange(10)]
            bjobs = [cjob(i, TaskB) for i in xrange(10)]
            cjobs = [cjob(i, TaskC) for i in xrange(10)]
            djobs = [cjob(i, TaskD) for i in xrange(10)]

            # Spread the jobs around.
            jobs = list(chain(*izip(ajobs, bjobs, cjobs, djobs)))

            [b.put(job) for job in jobs]
            got = []
            for job in jobs:
                sys.stderr.write("0")
                got.append(b.get())
            # The original used assertTrue(b.get(), job), which ignores
            # the second argument entirely. Cross-bucket ordering is
            # documented as undefined, so check that exactly the jobs we
            # put in come back out, ignoring order.
            self.assertEqual(sorted(j.task_id for j in got),
                             sorted(j.task_id for j in jobs))
        finally:
            self.registry.unregister(TaskD)
+
if __name__ == "__main__":
    # Allow running this test module directly.
    unittest.main()

+ 2 - 1
celery/tests/test_worker.py

@@ -158,7 +158,8 @@ class TestWorkController(unittest.TestCase):
 
     def test_attrs(self):
         worker = self.worker
-        self.assertTrue(isinstance(worker.bucket_queue, Queue))
+        self.assertTrue(hasattr(worker.bucket_queue, "get"))
+        self.assertTrue(hasattr(worker.bucket_queue, "put"))
         self.assertTrue(isinstance(worker.hold_queue, Queue))
         self.assertTrue(worker.periodic_work_controller)
         self.assertTrue(worker.pool)

+ 11 - 7
celery/worker/__init__.py

@@ -8,12 +8,13 @@ from celery.worker.controllers import Mediator, PeriodicWorkController
 from celery.worker.job import TaskWrapper
 from celery.exceptions import NotRegistered
 from celery.messaging import get_consumer_set
-from celery.conf import DAEMON_CONCURRENCY, DAEMON_LOG_FILE
-from celery.conf import AMQP_CONNECTION_RETRY, AMQP_CONNECTION_MAX_RETRIES
 from celery.log import setup_logger
 from celery.pool import TaskPool
 from celery.utils import retry_over_time
 from celery.datastructures import SharedCounter
+from celery import registry
+from celery import conf
+from celery.buckets import TaskBucket
 from Queue import Queue
 import traceback
 import logging
@@ -153,12 +154,12 @@ class AMQPListener(object):
             connected = conn.connection # Connection is established lazily.
             return conn
 
-        if not AMQP_CONNECTION_RETRY:
+        if not conf.AMQP_CONNECTION_RETRY:
             return _establish_connection()
 
         conn = retry_over_time(_establish_connection, (socket.error, IOError),
                                errback=_connection_error_handler,
-                               max_retries=AMQP_CONNECTION_MAX_RETRIES)
+                               max_retries=conf.AMQP_CONNECTION_MAX_RETRIES)
         self.logger.debug("AMQPListener: Connection Established.")
         return conn
 
@@ -222,8 +223,8 @@ class WorkController(object):
 
     """
     loglevel = logging.ERROR
-    concurrency = DAEMON_CONCURRENCY
-    logfile = DAEMON_LOG_FILE
+    concurrency = conf.DAEMON_CONCURRENCY
+    logfile = conf.DAEMON_LOG_FILE
     _state = None
 
     def __init__(self, concurrency=None, logfile=None, loglevel=None,
@@ -237,7 +238,10 @@ class WorkController(object):
         self.logger = setup_logger(loglevel, logfile)
 
         # Queues
-        self.bucket_queue = Queue()
+        if conf.DISABLE_RATE_LIMITS:
+            self.bucket_queue = Queue()
+        else:
+            self.bucket_queue = TaskBucket(task_registry=registry.tasks)
         self.hold_queue = Queue()
 
         self.logger.debug("Instantiating thread components...")