Explorar el Código

Implementation of rate limits per. task using token buckets.

Ask Solem hace 15 años
padre
commit
8ff75a73c0
Se han modificado 4 ficheros con 215 adiciones y 8 borrados
  1. 173 0
      celery/buckets.py
  2. 23 1
      celery/conf.py
  3. 8 0
      celery/task/base.py
  4. 11 7
      celery/worker/__init__.py

+ 173 - 0
celery/buckets.py

@@ -0,0 +1,173 @@
+import time
+from Queue import Queue, Empty as QueueEmpty
+
+RATE_MODIFIER_MAP = {"s": lambda n: n,
+                     "m": lambda n: n / 60.0,
+                     "h": lambda n: n / 60.0 / 60.0}
+
+
+class BucketRateExceeded(Exception):
+    """The token buckets rate limit has been exceeded."""
+
+
+def parse_ratelimit_string(rate_limit):
+    """Parse rate limit configurations such as ``"100/m"`` or ``"2/h"``
+        and convert them into seconds."""
+    if rate_limit:
+        try:
+            return int(rate_limit)
+        except ValueError:
+            ops, _, modifier = rate_limit.partition("/")
+            return RATE_MODIFIER_MAP[modifier](int(ops))
+    return None
+
+
+class TaskBucket(object):
+    """A bucket with buckets of tasks. (eh. seriously.)
+
+    This is a collection of token buckets, each task type having
+    its own token bucket. If the task type doesn't have a rate limit,
+    it will have a plain Queue object instead of a token bucket queue.
+
+    The :meth:`put` operation forwards the task to its appropriate bucket,
+    while the :meth:`get` operation iterates over the buckets and retrieves
+    the first available item.
+
+    Say we have three types of tasks in the registry: ``celery.ping``,
+    ``feed.refresh`` and ``video.compress``, the TaskBucket will consist
+    of the following items::
+
+        {"celery.ping": TokenBucketQueue(fill_rate=300),
+         "feed.refresh": Queue(),
+         "video.compress": TokenBucketQueue(fill_rate=2)}
+
+    The get operation will iterate over these until one of them
+    is able to return an item. The underlying datastructure is a ``dict``,
+    so the order is ignored here.
+
+    :param task_registry: The task registry used to get the task
+        type class for a given task name.
+
+
+    """
+
+    def __init__(self, task_registry):
+        self.task_registry = task_registry
+        self.buckets = {}
+        self.init_with_registry()
+
+    def put(self, task):
+        """Put a task into the appropiate bucket."""
+        self.buckets[task_name].put_nowait(task)
+
+    def get(self):
+        """Retrive the task from the first available bucket.
+
+        Available as in, there is an item in the queue and you can
+        consume tokens from it.
+
+        """
+        for bucket in self.buckets.values():
+            try:
+                item = bucket.get_nowait()
+            except (BucketRateExceeded, QueueEmpty):
+                pass
+            time.sleep(0.01)
+
+    def init_with_registry(self):
+        """Initialize with buckets for all the task types in the registry."""
+        map(self.add_bucket_for_type, self.task_registry.keys())
+
+    def get_bucket_for_type(self, task_name):
+        """Get the bucket for a particular task type."""
+        if not task_name in self.buckets:
+            return self.add_bucket_for_type(task_name)
+        return self.buckets[task_name]
+
+    def add_bucket_for_type(self, task_name):
+        """Add a bucket for a task type.
+
+        Will read the tasks rate limit and create a :class:`TokenBucketQueue`
+        if it has one. If the task doesn't have a rate limit a regular Queue
+        will be used.
+
+        """
+        task_type = self.task_registry[task_name]
+        task_queue = Queue()
+        rate_limit = parse_ratelimit_string(task_type.rate_limit)
+        if rate_limit:
+            task_queue = TokenBucketQueue(rate_limit, queue=task_queue)
+
+        self.buckets[task_name] = task_queue
+        return bucket
+
+
+class TokenBucketQueue(object):
+    """An implementation of the token bucket algorithm.
+
+    See http://en.wikipedia.org/wiki/Token_Bucket
+    Most of this code was stolen from an entry in the ASPN Python Cookbook:
+    http://code.activestate.com/recipes/511490/
+
+    :param fill_rate: see :attr:`fill_rate`.
+    :keyword capacity: see :attr:`capacity`.
+
+    .. attribute:: fill_rate
+
+        The rate in tokens/second that the bucket will be refilled.
+
+    .. attribute:: capacity
+
+        Maximum number of tokens in the bucket. Default is ``1``.
+
+    .. attribute:: timestamp
+
+        Timestamp of the last time a token was taken out of the bucket.
+
+    """
+    def __init__(self, fill_rate, queue=None, capacity=1):
+        self.capacity = float(capacity)
+        self._tokens = self.capacity
+        self.queue = queue
+        if not self.queue:
+            self.queue = Queue()
+        self.fill_rate = float(fill_rate)
+        self.timestamp = time.time()
+
+    def put(self, item, nb=True):
+        put = self.queue.put_nowait if nb else self.queue.put
+        put(item)
+
+    def get(self, nb=True):
+        get = self.queue.get_nowait if nb else self.queue.get
+
+        if not self.can_consume(1):
+            raise BucketRateExceeded()
+
+        return get()
+
+    def get_nowait(self):
+        return self.get(nb=True)
+
+    def put_nowait(self):
+        return self.put(nb=True)
+
+    def qsize(self):
+        return self.queue.qsize()
+
+    def can_consume(self, tokens=1):
+        """Consume tokens from the bucket. Returns True if there were
+        sufficient tokens otherwise False."""
+        if tokens <= self._get_tokens():
+            self._tokens -= tokens
+        else:
+            return False
+        return True
+
+    def _get_tokens(self):
+        if self._tokens < self.capacity:
+            now = time.time()
+            delta = self.fill_rate * (now - self.timestamp)
+            self._tokens = min(self.capacity, self._tokens + delta)
+            self.timestamp = now
+        return self._tokens

+ 23 - 1
celery/conf.py

@@ -22,6 +22,7 @@ DEFAULT_AMQP_CONNECTION_MAX_RETRIES = 100
 DEFAULT_TASK_SERIALIZER = "pickle"
 DEFAULT_BACKEND = "database"
 DEFAULT_PERIODIC_STATUS_BACKEND = "database"
+DEFAULT_DISABLE_RATE_LIMITS = False
 
 
 """
@@ -254,7 +255,6 @@ CELERY_PERIODIC_STATUS_BACKEND = getattr(settings,
                                     "CELERY_PERIODIC_STATUS_BACKEND",
                                     DEFAULT_PERIODIC_STATUS_BACKEND)
 
-
 """
 
 .. data:: CELERY_CACHE_BACKEND
@@ -264,3 +264,25 @@ cache backend in ``CACHE_BACKEND`` will be used.
 
 """
 CELERY_CACHE_BACKEND = getattr(settings, "CELERY_CACHE_BACKEND", None)
+
+
+"""
+
+.. data:: DEFAULT_RATE_LIMIT
+
+The default rate limit applied to all tasks which doesn't have a custom
+rate limit defined. (Default: None)
+
+"""
+DEFAULT_RATE_LIMIT = getattr(settings, "CELERY_DEFAULT_RATE_LIMIT", None)
+
+"""
+
+.. data:: DISABLE_RATE_LIMITS
+
+If ``True`` all rate limits will be disabled and all tasks will be executed
+as soon as possible.
+
+"""
+DISABLE_RATE_LIMITS = getattr(settings, "CELERY_DISABLE_RATE_LIMITS",
+                              DEFAULT_DISABLE_RATE_LIMITS)

+ 8 - 0
celery/task/base.py

@@ -70,6 +70,13 @@ class Task(object):
         Defeault time in seconds before a retry of the task should be
         executed. Default is a 1 minute delay.
 
+    .. rate_limit:: Set the rate limit for this task type,
+        if this is ``None`` no rate limit is in effect.
+        The rate limits can be specified in seconds, minutes or hours
+        by appending ``"/s"``, ``"/m"`` or "``/h"``". If this is an integer
+        it is interpreted as seconds. Example: ``"100/m" (hundred tasks a
+        minute).
+
     .. attribute:: ignore_result
 
         Don't store the status and return value. This means you can't
@@ -138,6 +145,7 @@ class Task(object):
     max_retries = 3
     default_retry_delay = 3 * 60
     serializer = conf.TASK_SERIALIZER
+    rate_limit = None
 
     MaxRetriesExceededError = MaxRetriesExceededError
 

+ 11 - 7
celery/worker/__init__.py

@@ -8,12 +8,13 @@ from celery.worker.controllers import Mediator, PeriodicWorkController
 from celery.worker.job import TaskWrapper
 from celery.exceptions import NotRegistered
 from celery.messaging import get_consumer_set
-from celery.conf import DAEMON_CONCURRENCY, DAEMON_LOG_FILE
-from celery.conf import AMQP_CONNECTION_RETRY, AMQP_CONNECTION_MAX_RETRIES
 from celery.log import setup_logger
 from celery.pool import TaskPool
 from celery.utils import retry_over_time
 from celery.datastructures import SharedCounter
+from celery import registry
+from celery import conf
+from celery.buckets import TaskBucket
 from Queue import Queue
 import traceback
 import logging
@@ -153,12 +154,12 @@ class AMQPListener(object):
             connected = conn.connection # Connection is established lazily.
             return conn
 
-        if not AMQP_CONNECTION_RETRY:
+        if not conf.AMQP_CONNECTION_RETRY:
             return _establish_connection()
 
         conn = retry_over_time(_establish_connection, (socket.error, IOError),
                                errback=_connection_error_handler,
-                               max_retries=AMQP_CONNECTION_MAX_RETRIES)
+                               max_retries=conf.AMQP_CONNECTION_MAX_RETRIES)
         self.logger.debug("AMQPListener: Connection Established.")
         return conn
 
@@ -222,8 +223,8 @@ class WorkController(object):
 
     """
     loglevel = logging.ERROR
-    concurrency = DAEMON_CONCURRENCY
-    logfile = DAEMON_LOG_FILE
+    concurrency = conf.DAEMON_CONCURRENCY
+    logfile = conf.DAEMON_LOG_FILE
     _state = None
 
     def __init__(self, concurrency=None, logfile=None, loglevel=None,
@@ -237,7 +238,10 @@ class WorkController(object):
         self.logger = setup_logger(loglevel, logfile)
 
         # Queues
-        self.bucket_queue = Queue()
+        if conf.DISABLE_RATE_LIMITS:
+            self.bucket_queue = Queue()
+        else:
+            self.bucket_queue = TaskBucket(task_registry=registry.tasks)
         self.hold_queue = Queue()
 
         self.logger.debug("Instantiating thread components...")