فهرست منبع

QOS: Make sure QoS is enabled before consuming. Also qos update happens at once with eta tasks.

Ask Solem 15 سال پیش
والد
کامیت
b7e5c93865
2فایلهای تغییر یافته به همراه19 افزوده شده و 8 حذف شده
  1. 2 0
      celery/datastructures.py
  2. 17 8
      celery/worker/listener.py

+ 2 - 0
celery/datastructures.py

@@ -140,10 +140,12 @@ class SharedCounter(object):
     def increment(self, n=1):
     def increment(self, n=1):
         """Increment value."""
         """Increment value."""
         self += n
         self += n
+        return int(self)
 
 
     def decrement(self, n=1):
     def decrement(self, n=1):
         """Decrement value."""
         """Decrement value."""
         self -= n
         self -= n
+        return int(self)
 
 
     def _update_value(self):
     def _update_value(self):
         self._value += sum(consume_queue(self._modify_queue))
         self._value += sum(consume_queue(self._modify_queue))

+ 17 - 8
celery/worker/listener.py

@@ -95,11 +95,19 @@ class CarrotListener(object):
         while 1:
         while 1:
             pcount = int(self.prefetch_count) # SharedCounter() -> int()
             pcount = int(self.prefetch_count) # SharedCounter() -> int()
             if not self.prev_pcount or pcount != self.prev_pcount:
             if not self.prev_pcount or pcount != self.prev_pcount:
-                self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
-                task_consumer.qos(prefetch_count=pcount)
-                self.prev_pcount = pcount
+                self.update_task_qos(pcount)
             wait_for_message()
             wait_for_message()
 
 
+    def task_qos_increment(self):
+        self.update_task_qos(self.prefetch_count.increment())
+
+    def task_qos_decrement(self):
+        self.update_task_qos(self.prefetch_count.decrement())
+
+    def update_task_qos(self, pcount):
+        self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
+        self.task_consumer.qos(prefetch_count=pcount)
+        self.prev_pcount = pcount
 
 
     def on_task(self, task, eta=None):
     def on_task(self, task, eta=None):
         """Handle received task.
         """Handle received task.
@@ -121,11 +129,11 @@ class CarrotListener(object):
         if eta:
         if eta:
             if not isinstance(eta, datetime):
             if not isinstance(eta, datetime):
                 eta = parse_iso8601(eta)
                 eta = parse_iso8601(eta)
-            self.prefetch_count.increment()
+            self.task_qos_increment()
             self.logger.info("Got task from broker: %s[%s] eta:[%s]" % (
             self.logger.info("Got task from broker: %s[%s] eta:[%s]" % (
                     task.task_name, task.task_id, eta))
                     task.task_name, task.task_id, eta))
             self.eta_schedule.enter(task, eta=eta,
             self.eta_schedule.enter(task, eta=eta,
-                                    callback=self.prefetch_count.decrement)
+                                    callback=self.task_qos_decrement)
         else:
         else:
             self.logger.info("Got task from broker: %s[%s]" % (
             self.logger.info("Got task from broker: %s[%s]" % (
                     task.task_name, task.task_id))
                     task.task_name, task.task_id))
@@ -213,13 +221,14 @@ class CarrotListener(object):
         self.ready_queue.clear()
         self.ready_queue.clear()
         self.eta_schedule.clear()
         self.eta_schedule.clear()
 
 
-        # Reset prefetch window.
-        self.prefetch_count = SharedCounter(self.initial_prefetch_count)
-        self.prev_pcount = None
 
 
         self.connection = self._open_connection()
         self.connection = self._open_connection()
         self.logger.debug("CarrotListener: Connection Established.")
         self.logger.debug("CarrotListener: Connection Established.")
         self.task_consumer = get_consumer_set(connection=self.connection)
         self.task_consumer = get_consumer_set(connection=self.connection)
+        # QoS: Reset prefetch window.
+        self.prefetch_count = SharedCounter(self.initial_prefetch_count)
+        self.update_task_qos(int(self.prefetch_count))
+
         self.task_consumer.on_decode_error = self.on_decode_error
         self.task_consumer.on_decode_error = self.on_decode_error
         self.broadcast_consumer = BroadcastConsumer(self.connection,
         self.broadcast_consumer = BroadcastConsumer(self.connection,
                                                     hostname=self.hostname)
                                                     hostname=self.hostname)