ソースを参照

Optimizes prefetch count increments.

``qos.increment`` is now ``qos.increment_eventually``,
like decrement_eventually this will only change the channels prefetch count
when ``qos.update()`` is called.

``qos.update`` is now only called when there are no more messages to receive,
not for every message received as previously:
This enables us to group qos calls together, so that it only e.g
calls basic_qos for every 10 increments, instead of 10 suceeding calls within
the same millisecond.

This also fixes a bug where the worker could not respond to remote control commands
while it was busy receiving lots of ETA/countdown tasks.
Ask Solem 12 年 前
コミット
7e4ca7009f
2 ファイル変更48 行追加46 行削除
  1. 28 24
      celery/tests/worker/test_worker.py
  2. 20 22
      celery/worker/consumer.py

+ 28 - 24
celery/tests/worker/test_worker.py

@@ -111,29 +111,27 @@ class test_QoS(Case):
 
     def test_qos_increment_decrement(self):
         qos = self._QoS(10)
-        self.assertEqual(qos.increment(), 11)
-        self.assertEqual(qos.increment(3), 14)
-        self.assertEqual(qos.increment(-30), 14)
-        self.assertEqual(qos.decrement(7), 7)
-        self.assertEqual(qos.decrement(), 6)
-        with self.assertRaises(AssertionError):
-            qos.decrement(10)
+        self.assertEqual(qos.increment_eventually(), 11)
+        self.assertEqual(qos.increment_eventually(3), 14)
+        self.assertEqual(qos.increment_eventually(-30), 14)
+        self.assertEqual(qos.decrement_eventually(7), 7)
+        self.assertEqual(qos.decrement_eventually(), 6)
 
     def test_qos_disabled_increment_decrement(self):
         qos = self._QoS(0)
-        self.assertEqual(qos.increment(), 0)
-        self.assertEqual(qos.increment(3), 0)
-        self.assertEqual(qos.increment(-30), 0)
-        self.assertEqual(qos.decrement(7), 0)
-        self.assertEqual(qos.decrement(), 0)
-        self.assertEqual(qos.decrement(10), 0)
+        self.assertEqual(qos.increment_eventually(), 0)
+        self.assertEqual(qos.increment_eventually(3), 0)
+        self.assertEqual(qos.increment_eventually(-30), 0)
+        self.assertEqual(qos.decrement_eventually(7), 0)
+        self.assertEqual(qos.decrement_eventually(), 0)
+        self.assertEqual(qos.decrement_eventually(10), 0)
 
     def test_qos_thread_safe(self):
         qos = self._QoS(10)
 
         def add():
             for i in xrange(1000):
-                qos.increment()
+                qos.increment_eventually()
 
         def sub():
             for i in xrange(1000):
@@ -158,13 +156,13 @@ class test_QoS(Case):
         qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
         qos.update()
         self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-        qos.increment()
+        qos.increment_eventually()
         self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.increment()
+        qos.increment_eventually()
         self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
-        qos.decrement()
+        qos.decrement_eventually()
         self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.decrement()
+        qos.decrement_eventually()
         self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
 
     def test_consumer_increment_decrement(self):
@@ -173,7 +171,8 @@ class test_QoS(Case):
         qos.update()
         self.assertEqual(qos.value, 10)
         consumer.qos.assert_called_with(prefetch_count=10)
-        qos.decrement()
+        qos.decrement_eventually()
+        qos.update()
         self.assertEqual(qos.value, 9)
         consumer.qos.assert_called_with(prefetch_count=9)
         qos.decrement_eventually()
@@ -183,9 +182,9 @@ class test_QoS(Case):
 
         # Does not decrement 0 value
         qos.value = 0
-        qos.decrement()
+        qos.decrement_eventually()
         self.assertEqual(qos.value, 0)
-        qos.increment()
+        qos.increment_eventually()
         self.assertEqual(qos.value, 0)
 
     def test_consumer_decrement_eventually(self):
@@ -433,8 +432,12 @@ class test_Consumer(Case):
         l.consume_messages()
         self.assertTrue(l.task_consumer.consume.call_count)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
-        l.qos.decrement()
-        l.consume_messages()
+        l.task_consumer.qos = Mock()
+        self.assertEqual(l.qos.value, 10)
+        l.qos.decrement_eventually()
+        self.assertEqual(l.qos.value, 9)
+        l.qos.update()
+        self.assertEqual(l.qos.value, 9)
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
     def test_maybe_conn_error(self):
@@ -467,6 +470,7 @@ class test_Consumer(Case):
 
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
+        current_pcount = l.qos.value
         l.event_dispatcher = Mock()
         l.enabled = False
         l.update_strategies()
@@ -479,7 +483,7 @@ class test_Consumer(Case):
             if item.args[0].name == foo_task.name:
                 found = True
         self.assertTrue(found)
-        self.assertTrue(l.task_consumer.qos.call_count)
+        self.assertGreater(l.qos.value, current_pcount)
         l.timer.stop()
 
     def test_on_control(self):

+ 20 - 22
celery/worker/consumer.py

@@ -196,30 +196,22 @@ class QoS(object):
     def __init__(self, consumer, initial_value):
         self.consumer = consumer
         self._mutex = threading.RLock()
-        self.value = initial_value
+        self.value = initial_value or 0
 
-    def increment(self, n=1):
-        """Increment the current prefetch count value by n."""
-        with self._mutex:
-            if self.value:
-                new_value = self.value + max(n, 0)
-                self.value = self.set(new_value)
-        return self.value
+    def increment_eventually(self, n=1):
+        """Increment the value, but do not update the channels QoS.
 
-    def _sub(self, n=1):
-        assert self.value - n > 1
-        self.value -= n
+        The MainThread will be responsible for calling :meth:`update`
+        when necessary.
 
-    def decrement(self, n=1):
-        """Decrement the current prefetch count value by n."""
+        """
         with self._mutex:
             if self.value:
-                self._sub(n)
-                self.set(self.value)
+                self.value = self.value + max(n, 0)
         return self.value
 
     def decrement_eventually(self, n=1):
-        """Decrement the value, but do not update the qos.
+        """Decrement the value, but do not update the channels QoS.
 
         The MainThread will be responsible for calling :meth:`update`
         when necessary.
@@ -227,7 +219,9 @@ class QoS(object):
         """
         with self._mutex:
             if self.value:
-                self._sub(n)
+                self.value -= n
+                print("DECREMENT %r" % (self.value, ))
+        return self.value
 
     def set(self, pcount):
         """Set channel prefetch_count setting."""
@@ -245,6 +239,7 @@ class QoS(object):
     def update(self):
         """Update prefetch count with current value."""
         with self._mutex:
+            print("SET: %r " % (self.value, ))
             return self.set(self.value)
 
 
@@ -405,7 +400,7 @@ class Consumer(object):
                     self.handle_unknown_task(body, message, exc)
                 except InvalidTaskError, exc:
                     self.handle_invalid_task(body, message, exc)
-                fire_timers()
+                #fire_timers()
 
             self.task_consumer.callbacks = [on_task_received]
             self.task_consumer.consume()
@@ -423,13 +418,16 @@ class Consumer(object):
                 # the number of seconds until we need to fire timers again.
                 poll_timeout = fire_timers() if scheduled else 1
 
+                # We only update QoS when there is no more messages to read.
+                # This groups together qos calls, and makes sure that remote
+                # control commands will be prioritized over task messages.
+                if qos.prev != qos.value:
+                    update_qos()
+
                 update_readers(on_poll_start())
                 if readers or writers:
                     connection.more_to_read = True
                     while connection.more_to_read:
-                        if qos.prev != qos.value:
-                            update_qos()
-
                         for fileno, event in poll(poll_timeout) or ():
                             try:
                                 if event & READ:
@@ -485,7 +483,7 @@ class Consumer(object):
                       task.eta, exc, task.info(safe=True), exc_info=True)
                 task.acknowledge()
             else:
-                self.qos.increment()
+                self.qos.increment_eventually()
                 self.timer.apply_at(eta, self.apply_eta_task, (task, ),
                                     priority=6)
         else: