|
@@ -134,26 +134,31 @@ class TaskBucket(object):
|
|
|
"""Initialize with buckets for all the task types in the registry."""
|
|
|
map(self.add_bucket_for_type, self.task_registry.keys())
|
|
|
|
|
|
+ def refresh(self):
|
|
|
+ """Refresh rate limits for all task types in the registry."""
|
|
|
+ map(self.update_bucket_for_type, self.task_registry.keys())
|
|
|
+
|
|
|
def get_bucket_for_type(self, task_name):
|
|
|
"""Get the bucket for a particular task type."""
|
|
|
if task_name not in self.buckets:
|
|
|
return self.add_bucket_for_type(task_name)
|
|
|
return self.buckets[task_name]
|
|
|
|
|
|
- def add_bucket_for_type(self, task_name):
|
|
|
- """Add a bucket for a task type.
|
|
|
-
|
|
|
- Will read the tasks rate limit and create a :class:`TokenBucketQueue`
|
|
|
- if it has one. If the task doesn't have a rate limit a regular Queue
|
|
|
- will be used.
|
|
|
+ def _get_queue_for_type(self, task_name):
|
|
|
+ bucket = self.buckets[task_name]
|
|
|
+ if isinstance(bucket, TokenBucketQueue):
|
|
|
+ return bucket.queue
|
|
|
+ return bucket
|
|
|
|
|
|
- """
|
|
|
- if task_name in self.buckets:
|
|
|
- return
|
|
|
+ def update_bucket_for_type(self, task_name):
|
|
|
task_type = self.task_registry[task_name]
|
|
|
- task_queue = task_type.rate_limit_queue_type()
|
|
|
rate_limit = getattr(task_type, "rate_limit", None)
|
|
|
rate_limit = parse_ratelimit_string(rate_limit)
|
|
|
+ if task_name in self.buckets:
|
|
|
+ task_queue = self._get_queue_for_type(task_name)
|
|
|
+ else:
|
|
|
+ task_queue = task_type.rate_limit_queue_type()
|
|
|
+
|
|
|
if rate_limit:
|
|
|
task_queue = TokenBucketQueue(rate_limit, queue=task_queue)
|
|
|
else:
|
|
@@ -162,6 +167,18 @@ class TaskBucket(object):
|
|
|
self.buckets[task_name] = task_queue
|
|
|
return task_queue
|
|
|
|
|
|
+ def add_bucket_for_type(self, task_name):
|
|
|
+ """Add a bucket for a task type.
|
|
|
+
|
|
|
+ Will read the tasks rate limit and create a :class:`TokenBucketQueue`
|
|
|
+ if it has one. If the task doesn't have a rate limit a regular Queue
|
|
|
+ will be used.
|
|
|
+
|
|
|
+ """
|
|
|
+ if task_name not in self.buckets:
|
|
|
+ return self.update_bucket_for_type(task_name)
|
|
|
+
|
|
|
+
|
|
|
def qsize(self):
|
|
|
"""Get the total size of all the queues."""
|
|
|
return sum(bucket.qsize() for bucket in self.buckets.values())
|