Browse Source

Fixes rate limits. Closes #33166

Ask Solem 9 years ago
parent
commit
d23c8c297a
2 changed files with 63 additions and 34 deletions
  1. 29 23
      celery/tests/worker/test_consumer.py
  2. 34 11
      celery/worker/consumer/consumer.py

+ 29 - 23
celery/tests/worker/test_consumer.py

@@ -3,6 +3,8 @@ from __future__ import absolute_import, unicode_literals
 import errno
 import socket
 
+from collections import deque
+
 from billiard.exceptions import RestartFreqExceeded
 
 from celery.datastructures import LimitedSet
@@ -13,7 +15,9 @@ from celery.worker.consumer.heart import Heart
 from celery.worker.consumer.mingle import Mingle
 from celery.worker.consumer.tasks import Tasks
 
-from celery.tests.case import AppCase, ContextMock, Mock, call, patch, skip
+from celery.tests.case import (
+    AppCase, ContextMock, Mock, call, patch, skip,
+)
 
 
 class test_Consumer(AppCase):
@@ -103,30 +107,32 @@ class test_Consumer(AppCase):
 
     def test_limit_task(self):
         c = self.get_consumer()
+        c.timer = Mock()
 
-        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
-            bucket = Mock()
-            request = Mock()
-            bucket.can_consume.return_value = True
-
-            c._limit_task(request, bucket, 3)
-            bucket.can_consume.assert_called_with(3)
-            reserv.assert_called_with(request)
-            c.on_task_request.assert_called_with(request)
+        bucket = Mock()
+        request = Mock()
+        bucket.can_consume.return_value = True
+        bucket.contents = deque()
+
+        c._limit_task(request, bucket, 3)
+        bucket.can_consume.assert_called_with(3)
+        bucket.expected_time.assert_called_with(3)
+        c.timer.call_after.assert_called_with(
+            bucket.expected_time(), c._on_bucket_wakeup, (bucket, 3),
+            priority=c._limit_order,
+        )
 
-        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
-            bucket.can_consume.return_value = False
-            bucket.expected_time.return_value = 3.33
-            limit_order = c._limit_order
-            c._limit_task(request, bucket, 4)
-            self.assertEqual(c._limit_order, limit_order + 1)
-            bucket.can_consume.assert_called_with(4)
-            c.timer.call_after.assert_called_with(
-                3.33, c._limit_move_to_pool, (request,),
-                priority=c._limit_order,
-            )
-            bucket.expected_time.assert_called_with(4)
-            reserv.assert_not_called()
+        bucket.can_consume.return_value = False
+        bucket.expected_time.return_value = 3.33
+        limit_order = c._limit_order
+        c._limit_task(request, bucket, 4)
+        self.assertEqual(c._limit_order, limit_order + 1)
+        bucket.can_consume.assert_called_with(4)
+        c.timer.call_after.assert_called_with(
+            3.33, c._on_bucket_wakeup, (bucket, 4),
+            priority=c._limit_order,
+        )
+        bucket.expected_time.assert_called_with(4)
 
     def test_start_blueprint_raises_EMFILE(self):
         c = self.get_consumer()

+ 34 - 11
celery/worker/consumer/consumer.py

@@ -29,7 +29,7 @@ from celery import bootsteps
 from celery import signals
 from celery.app.trace import build_tracer
 from celery.exceptions import InvalidTaskError, NotRegistered
-from celery.five import buffer_t, items, python_2_unicode_compatible
+from celery.five import buffer_t, items, python_2_unicode_compatible, values
 from celery.utils import gethostname
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
@@ -270,17 +270,37 @@ class Consumer(object):
         task_reserved(request)
         self.on_task_request(request)
 
-    def _limit_task(self, request, bucket, tokens):
-        if not bucket.can_consume(tokens):
-            hold = bucket.expected_time(tokens)
-            pri = self._limit_order = (self._limit_order + 1) % 10
-            self.timer.call_after(
-                hold, self._limit_move_to_pool, (request,),
-                priority=pri,
-            )
+    def _on_bucket_wakeup(self, bucket, tokens):
+        try:
+            request = bucket.pop()
+        except IndexError:
+            pass
         else:
-            task_reserved(request)
-            self.on_task_request(request)
+            self._limit_move_to_pool(request)
+            self._schedule_oldest_bucket_request(bucket, tokens)
+
+    def _schedule_oldest_bucket_request(self, bucket, tokens):
+        try:
+            request = bucket.pop()
+        except IndexError:
+            pass
+        else:
+            return self._schedule_bucket_request(request, bucket, tokens)
+
+    def _schedule_bucket_request(self, request, bucket, tokens):
+        bucket.can_consume(tokens)
+        bucket.add(request)
+        pri = self._limit_order = (self._limit_order + 1) % 10
+        hold = bucket.expected_time(tokens)
+        self.timer.call_after(
+            hold, self._on_bucket_wakeup, (bucket, tokens),
+            priority=pri,
+        )
+
+    def _limit_task(self, request, bucket, tokens):
+        if bucket.contents:
+            return bucket.add(request)
+        return self._schedule_bucket_request(request, bucket, tokens)
 
     def start(self):
         blueprint = self.blueprint
@@ -369,6 +389,9 @@ class Consumer(object):
             self.controller.semaphore.clear()
         if self.timer:
             self.timer.clear()
+        for bucket in values(self.task_buckets):
+            if bucket:
+                bucket.clear_pending()
         reserved_requests.clear()
         if self.pool and self.pool.flush:
             self.pool.flush()