Browse Source

Can't set QoS in another thread, so decrement after the next received channel event.

Ask Solem 15 years ago
parent
commit
0bae96d1ee
4 changed files with 54 additions and 25 deletions
  1. 11 1
      celery/tests/test_worker.py
  2. 2 2
      celery/worker/controllers.py
  3. 39 21
      celery/worker/listener.py
  4. 2 1
      celery/worker/scheduler.py

+ 11 - 1
celery/tests/test_worker.py

@@ -12,7 +12,7 @@ from celery.utils import gen_unique_id
 from celery.worker import WorkController
 from celery.worker import WorkController
 from celery.worker.job import TaskWrapper
 from celery.worker.job import TaskWrapper
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
-from celery.worker.listener import CarrotListener, RUN
+from celery.worker.listener import CarrotListener, QoS, RUN
 from celery.worker.scheduler import Scheduler
 from celery.worker.scheduler import Scheduler
 from celery.decorators import task as task_dec
 from celery.decorators import task as task_dec
 from celery.decorators import periodic_task as periodic_task_dec
 from celery.decorators import periodic_task as periodic_task_dec
@@ -249,6 +249,13 @@ class TestCarrotListener(unittest.TestCase):
         self.assertTrue(self.eta_schedule.empty())
         self.assertTrue(self.eta_schedule.empty())
 
 
     def test_receieve_message_eta_isoformat(self):
     def test_receieve_message_eta_isoformat(self):
+
+        class MockConsumer(object):
+            prefetch_count_incremented = False
+
+            def qos(self, **kwargs):
+                self.prefetch_count_incremented = True
+
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
                            send_events=False)
         backend = MockBackend()
         backend = MockBackend()
@@ -257,6 +264,8 @@ class TestCarrotListener(unittest.TestCase):
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
 
 
         l.event_dispatcher = MockEventDispatcher()
         l.event_dispatcher = MockEventDispatcher()
+        l.task_consumer = MockConsumer()
+        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
         l.receive_message(m.decode(), m)
         l.receive_message(m.decode(), m)
 
 
         items = [entry[2] for entry in self.eta_schedule.queue]
         items = [entry[2] for entry in self.eta_schedule.queue]
@@ -265,6 +274,7 @@ class TestCarrotListener(unittest.TestCase):
             if item.task_name == foo_task.name:
             if item.task_name == foo_task.name:
                 found = True
                 found = True
         self.assertTrue(found)
         self.assertTrue(found)
+        self.assertTrue(l.task_consumer.prefetch_count_incremented)
 
 
     def test_revoke(self):
     def test_revoke(self):
         ready_queue = FastQueue()
         ready_queue = FastQueue()

+ 2 - 2
celery/worker/controllers.py

@@ -117,6 +117,6 @@ class ScheduleController(BackgroundThread):
         if delay is None:
         if delay is None:
             delay = 1
             delay = 1
 
 
-        self.debug("ScheduleController: Scheduler wake-up",
-              "ScheduleController: Next wake-up eta %s seconds..." % delay)
+        self.debug("ScheduleController: Scheduler wake-up"
+                "ScheduleController: Next wake-up eta %s seconds..." % delay)
         time.sleep(delay)
         time.sleep(delay)

+ 39 - 21
celery/worker/listener.py

@@ -23,6 +23,39 @@ RUN = 0x0
 CLOSE = 0x1
 CLOSE = 0x1
 
 
 
 
+class QoS(object):
+    prev = None
+
+    def __init__(self, consumer, initial_value, logger):
+        self.consumer = consumer
+        self.logger = logger
+        self.value = SharedCounter(initial_value)
+
+        self.set(int(self.value))
+
+    def increment(self):
+        return self.set(self.value.increment())
+
+    def decrement(self):
+        return self.set(self.value.decrement())
+
+    def decrement_eventually(self):
+        self.value.decrement()
+
+    def set(self, pcount):
+        self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
+        self.consumer.qos(prefetch_count=pcount)
+        self.prev = pcount
+        return pcount
+
+    def update(self):
+        return self.set(self.next)
+
+    @property
+    def next(self):
+        return int(self.value)
+
+
 class CarrotListener(object):
 class CarrotListener(object):
     """Listen for messages received from the broker and
     """Listen for messages received from the broker and
     move them the the ready queue for task processing.
     move them the the ready queue for task processing.
@@ -60,8 +93,6 @@ class CarrotListener(object):
         self.control_dispatch = ControlDispatch(logger=logger,
         self.control_dispatch = ControlDispatch(logger=logger,
                                                 hostname=self.hostname,
                                                 hostname=self.hostname,
                                                 listener=self)
                                                 listener=self)
-        self.prefetch_count = SharedCounter(self.initial_prefetch_count)
-        self.prev_pcount = None
         self.event_dispatcher = None
         self.event_dispatcher = None
         self.heart = None
         self.heart = None
         self._state = None
         self._state = None
@@ -93,22 +124,10 @@ class CarrotListener(object):
         self.logger.debug("CarrotListener: Ready to accept tasks!")
         self.logger.debug("CarrotListener: Ready to accept tasks!")
 
 
         while 1:
         while 1:
-            pcount = int(self.prefetch_count) # SharedCounter() -> int()
-            if not self.prev_pcount or pcount != self.prev_pcount:
-                self.update_task_qos(pcount)
+            if self.qos.prev != self.qos.next:
+                self.qos.update()
             wait_for_message()
             wait_for_message()
 
 
-    def task_qos_increment(self):
-        self.update_task_qos(self.prefetch_count.increment())
-
-    def task_qos_decrement(self):
-        self.update_task_qos(self.prefetch_count.decrement())
-
-    def update_task_qos(self, pcount):
-        self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
-        self.task_consumer.qos(prefetch_count=pcount)
-        self.prev_pcount = pcount
-
     def on_task(self, task, eta=None):
     def on_task(self, task, eta=None):
         """Handle received task.
         """Handle received task.
 
 
@@ -129,11 +148,11 @@ class CarrotListener(object):
         if eta:
         if eta:
             if not isinstance(eta, datetime):
             if not isinstance(eta, datetime):
                 eta = parse_iso8601(eta)
                 eta = parse_iso8601(eta)
-            self.task_qos_increment()
+            self.qos.increment()
             self.logger.info("Got task from broker: %s[%s] eta:[%s]" % (
             self.logger.info("Got task from broker: %s[%s] eta:[%s]" % (
                     task.task_name, task.task_id, eta))
                     task.task_name, task.task_id, eta))
             self.eta_schedule.enter(task, eta=eta,
             self.eta_schedule.enter(task, eta=eta,
-                                    callback=self.task_qos_decrement)
+                    callback=self.qos.decrement_eventually)
         else:
         else:
             self.logger.info("Got task from broker: %s[%s]" % (
             self.logger.info("Got task from broker: %s[%s]" % (
                     task.task_name, task.task_id))
                     task.task_name, task.task_id))
@@ -221,13 +240,12 @@ class CarrotListener(object):
         self.ready_queue.clear()
         self.ready_queue.clear()
         self.eta_schedule.clear()
         self.eta_schedule.clear()
 
 
-
         self.connection = self._open_connection()
         self.connection = self._open_connection()
         self.logger.debug("CarrotListener: Connection Established.")
         self.logger.debug("CarrotListener: Connection Established.")
         self.task_consumer = get_consumer_set(connection=self.connection)
         self.task_consumer = get_consumer_set(connection=self.connection)
         # QoS: Reset prefetch window.
         # QoS: Reset prefetch window.
-        self.prefetch_count = SharedCounter(self.initial_prefetch_count)
-        self.update_task_qos(int(self.prefetch_count))
+        self.qos = QoS(self.task_consumer,
+                       self.initial_prefetch_count, self.logger)
 
 
         self.task_consumer.on_decode_error = self.on_decode_error
         self.task_consumer.on_decode_error = self.on_decode_error
         self.broadcast_consumer = BroadcastConsumer(self.connection,
         self.broadcast_consumer = BroadcastConsumer(self.connection,

+ 2 - 1
celery/worker/scheduler.py

@@ -69,7 +69,8 @@ class Scheduler(object):
 
 
                     if event is verify:
                     if event is verify:
                         ready_queue.put(item)
                         ready_queue.put(item)
-                        callback and callback()
+                        if callback is not None:
+                            callback()
                         yield 0
                         yield 0
                     else:
                     else:
                         heapq.heappush(self._queue, event)
                         heapq.heappush(self._queue, event)