Parcourir la source

celeryd: basic_qos prefetch_count is a short, so can't exceed 0xffff. If this happens we disable the prefetch limit. Closes #359. Thanks to defcube

Ask Solem il y a 14 ans
Parent
commit
fc11106b04
2 fichiers modifiés avec 28 ajouts et 5 suppressions
  1. 16 1
      celery/tests/test_worker/test_worker.py
  2. 12 4
      celery/worker/consumer.py

+ 16 - 1
celery/tests/test_worker/test_worker.py

@@ -18,7 +18,7 @@ from celery.worker import WorkController
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.consumer import Consumer as MainConsumer
-from celery.worker.consumer import QoS, RUN
+from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX
 from celery.utils.serialization import pickle
 
 from celery.tests.compat import catch_warnings
@@ -236,6 +236,21 @@ class test_QoS(unittest.TestCase):
         def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
             self.prefetch_count = prefetch_count
 
+    def test_exceeds_short(self):
+        consumer = self.MockConsumer()
+        qos = QoS(consumer, PREFETCH_COUNT_MAX - 1,
+                current_app.log.get_default_logger())
+        qos.update()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+
     def test_consumer_increment_decrement(self):
         consumer = self.MockConsumer()
         qos = QoS(consumer, 10, current_app.log.get_default_logger())

+ 12 - 4
celery/worker/consumer.py

@@ -90,6 +90,9 @@ from celery.worker.heartbeat import Heart
 RUN = 0x1
 CLOSE = 0x2
 
+#: Prefetch count can't exceed short.
+PREFETCH_COUNT_MAX = 0xFFFF
+
 
 class QoS(object):
     """Quality of Service for Channel.
@@ -114,8 +117,8 @@ class QoS(object):
         self._mutex.acquire()
         try:
             if self.value:
-                self.value += max(n, 0)
-                self.set(self.value)
+                new_value = self.value + max(n, 0)
+                self.value = self.set(new_value)
             return self.value
         finally:
             self._mutex.release()
@@ -152,8 +155,13 @@ class QoS(object):
     def set(self, pcount):
         """Set channel prefetch_count setting."""
         if pcount != self.prev:
-            self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
-            self.consumer.qos(prefetch_count=pcount)
+            new_value = pcount
+            if pcount > PREFETCH_COUNT_MAX:
+                self.logger.warning("QoS: Disabled: prefetch_count exceeds %r" % (
+                    PREFETCH_COUNT_MAX, ))
+                new_value = 0
+            self.logger.debug("basic.qos: prefetch_count->%s" % new_value)
+            self.consumer.qos(prefetch_count=new_value)
             self.prev = pcount
         return pcount