Browse Source

celeryd: basic_qos prefetch_count is a short, so can't exceed 0xffff. If this happens we disable the prefetch limit. Closes #359. Thanks to defcube

Ask Solem 14 years ago
parent
commit
fc11106b04
2 changed files with 28 additions and 5 deletions
  1. 16 1
      celery/tests/test_worker/test_worker.py
  2. 12 4
      celery/worker/consumer.py

+ 16 - 1
celery/tests/test_worker/test_worker.py

@@ -18,7 +18,7 @@ from celery.worker import WorkController
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.job import TaskRequest
 from celery.worker.consumer import Consumer as MainConsumer
 from celery.worker.consumer import Consumer as MainConsumer
-from celery.worker.consumer import QoS, RUN
+from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 
 
 from celery.tests.compat import catch_warnings
 from celery.tests.compat import catch_warnings
@@ -236,6 +236,21 @@ class test_QoS(unittest.TestCase):
         def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
         def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
             self.prefetch_count = prefetch_count
             self.prefetch_count = prefetch_count
 
 
+    def test_exceeds_short(self):
+        consumer = self.MockConsumer()
+        qos = QoS(consumer, PREFETCH_COUNT_MAX - 1,
+                current_app.log.get_default_logger())
+        qos.update()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+
     def test_consumer_increment_decrement(self):
     def test_consumer_increment_decrement(self):
         consumer = self.MockConsumer()
         consumer = self.MockConsumer()
         qos = QoS(consumer, 10, current_app.log.get_default_logger())
         qos = QoS(consumer, 10, current_app.log.get_default_logger())

+ 12 - 4
celery/worker/consumer.py

@@ -90,6 +90,9 @@ from celery.worker.heartbeat import Heart
 RUN = 0x1
 RUN = 0x1
 CLOSE = 0x2
 CLOSE = 0x2
 
 
+#: Prefetch count can't exceed short.
+PREFETCH_COUNT_MAX = 0xFFFF
+
 
 
 class QoS(object):
 class QoS(object):
     """Quality of Service for Channel.
     """Quality of Service for Channel.
@@ -114,8 +117,8 @@ class QoS(object):
         self._mutex.acquire()
         self._mutex.acquire()
         try:
         try:
             if self.value:
             if self.value:
-                self.value += max(n, 0)
-                self.set(self.value)
+                new_value = self.value + max(n, 0)
+                self.value = self.set(new_value)
             return self.value
             return self.value
         finally:
         finally:
             self._mutex.release()
             self._mutex.release()
@@ -152,8 +155,13 @@ class QoS(object):
     def set(self, pcount):
     def set(self, pcount):
         """Set channel prefetch_count setting."""
         """Set channel prefetch_count setting."""
         if pcount != self.prev:
         if pcount != self.prev:
-            self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
-            self.consumer.qos(prefetch_count=pcount)
+            new_value = pcount
+            if pcount > PREFETCH_COUNT_MAX:
+                self.logger.warning("QoS: Disabled: prefetch_count exceeds %r" % (
+                    PREFETCH_COUNT_MAX, ))
+                new_value = 0
+            self.logger.debug("basic.qos: prefetch_count->%s" % new_value)
+            self.consumer.qos(prefetch_count=new_value)
             self.prev = pcount
             self.prev = pcount
         return pcount
         return pcount