Browse Source

Implement AMQP basic.qos (Quality of Service) to set the message prefetch
window. Closes #27

Ask Solem 15 years ago
parent
commit
c5b31df850

+ 70 - 0
celery/datastructures.py

@@ -4,6 +4,8 @@ Custom Datastructures
 
 """
 from UserList import UserList
+from Queue import Queue
+from Queue import Empty as QueueEmpty
 import traceback
 
 
@@ -70,3 +72,71 @@ class ExceptionInfo(object):
 
     def __str__(self):
         return str(self.exception)
+
+
+def consume_queue(queue):
+    while True:
+        try:
+            yield queue.get_nowait()
+        except QueueEmpty, exc:
+            raise StopIteration()
+
+
+class SharedCounter(object):
+    """An integer that can be updated by several threads at once.
+
+    Please note that the final value is not synchronized, this means
+    that you should not update the value on a previous value, the only
+    reliable operations are increment and decrement.
+
+    Example
+
+        >>> max_clients = SharedCounter(initial_value=10)
+
+        # Thread one
+        >>> max_clients += 1 # OK (safe)
+
+        # Thread two
+        >>> max_clients -= 3 # OK (safe)
+
+        # Main thread
+        >>> if client >= int(max_clients): # Max clients now at 8
+        ...    wait()
+
+
+        >>> max_client = max_clients + 10 # NOT OK (unsafe)
+
+    """
+    def __init__(self, initial_value):
+        self._value = initial_value
+        self._modify_queue = Queue()
+
+    def increment(self):
+        """Increment the value by one."""
+        self += 1
+
+    def decrement(self):
+        """Decrement the value by one."""
+        self -= 1
+        
+    def _update_value(self):
+        self._value += sum(consume_queue(self._modify_queue))
+        return self._value
+    
+    def __iadd__(self, y):
+        """``self += y``"""
+        self._modify_queue.put(y * +1)
+        return self
+
+    def __isub__(self, y):
+        """``self -= y``"""
+        self._modify_queue.put(y * -1)
+        return self
+
+    def __int__(self):
+        """``int(self) -> int``"""
+        self._update_value()
+        return self._value
+
+    def __repr__(self):
+        return "<SharedCounter: int(%s)>" % str(int(self))

+ 3 - 2
celery/tests/test_worker.py

@@ -136,10 +136,11 @@ class TestAMQPListener(unittest.TestCase):
         l.receive_message(m.decode(), m)
 
         in_hold = self.hold_queue.get_nowait()
-        self.assertEquals(len(in_hold), 2)
-        task, eta = in_hold
+        self.assertEquals(len(in_hold), 3)
+        task, eta, on_accept = in_hold
         self.assertTrue(isinstance(task, TaskWrapper))
         self.assertTrue(isinstance(eta, datetime))
+        self.assertTrue(callable(on_accept))
         self.assertEquals(task.task_name, "c.u.foo")
         self.assertEquals(task.execute(), 2 * 4 * 8)
         self.assertRaises(Empty, self.bucket_queue.get_nowait)

+ 9 - 4
celery/tests/test_worker_controllers.py

@@ -79,20 +79,25 @@ class TestPeriodicWorkController(unittest.TestCase):
         bucket_queue = Queue()
         hold_queue = Queue()
         m = PeriodicWorkController(bucket_queue, hold_queue)
-
         m.process_hold_queue()
 
+        scratchpad = {}
+        def on_accept():
+            scratchpad["accepted"] = True
+
         hold_queue.put((MockTask("task1"),
-                        datetime.now() - timedelta(days=1)))
+                        datetime.now() - timedelta(days=1),
+                        on_accept))
 
         m.process_hold_queue()
         self.assertRaises(Empty, hold_queue.get_nowait)
+        self.assertTrue(scratchpad.get("accepted"))
         self.assertEquals(bucket_queue.get_nowait().value, "task1")
         tomorrow = datetime.now() + timedelta(days=1)
-        hold_queue.put((MockTask("task2"), tomorrow))
+        hold_queue.put((MockTask("task2"), tomorrow, on_accept))
         m.process_hold_queue()
         self.assertRaises(Empty, bucket_queue.get_nowait)
-        value, eta = hold_queue.get_nowait()
+        value, eta, on_accept = hold_queue.get_nowait()
         self.assertEquals(value.value, "task2")
         self.assertEquals(eta, tomorrow)
 

+ 11 - 3
celery/worker/__init__.py

@@ -15,6 +15,7 @@ from celery.conf import AMQP_CONNECTION_RETRY, AMQP_CONNECTION_MAX_RETRIES
 from celery.log import setup_logger
 from celery.pool import TaskPool
 from celery.utils import retry_over_time
+from celery.datastructures import SharedCounter
 from Queue import Queue
 import traceback
 import logging
@@ -43,12 +44,14 @@ class AMQPListener(object):
 
     """
 
-    def __init__(self, bucket_queue, hold_queue, logger):
+    def __init__(self, bucket_queue, hold_queue, logger,
+            initial_prefetch_count=2):
         self.amqp_connection = None
         self.task_consumer = None
         self.bucket_queue = bucket_queue
         self.hold_queue = hold_queue
         self.logger = logger
+        self.prefetch_count = SharedCounter(initial_prefetch_count)
 
     def start(self):
         """Start the consumer.
@@ -77,6 +80,7 @@ class AMQPListener(object):
         self.logger.debug("AMQPListener: Ready to accept tasks!")
 
         while True:
+            self.task_consumer.qos(prefetch_count=int(self.prefetch_count))
             it.next()
 
     def stop(self):
@@ -102,12 +106,15 @@ class AMQPListener(object):
         if eta:
             self.logger.info("Got task from broker: %s[%s] eta:[%s]" % (
                     task.task_name, task.task_id, eta))
-            self.hold_queue.put((task, eta))
+            self.hold_queue.put((task, eta, self.prefetch_count.decrement))
+            self.prefetch_count.increment()
         else:
+            self.prefetch_count.decrement()
             self.logger.info("Got task from broker: %s[%s]" % (
                     task.task_name, task.task_id))
             self.bucket_queue.put(task)
 
+
     def close_connection(self):
         """Close the AMQP connection."""
         if self.task_consumer:
@@ -246,7 +253,8 @@ class WorkController(object):
                                                     self.hold_queue)
         self.pool = TaskPool(self.concurrency, logger=self.logger)
         self.amqp_listener = AMQPListener(self.bucket_queue, self.hold_queue,
-                                          logger=self.logger)
+                                          logger=self.logger,
+                                          initial_prefetch_count=concurrency)
         self.mediator = Mediator(self.bucket_queue, self.safe_process_task)
 
         # The order is important here;

+ 4 - 2
celery/worker/controllers.py

@@ -147,17 +147,19 @@ class PeriodicWorkController(BackgroundThread):
         try:
             logger.debug(
                 "PeriodicWorkController: Getting next task from hold queue..")
-            task, eta = self.hold_queue.get_nowait()
+            task, eta, on_accept = self.hold_queue.get_nowait()
         except QueueEmpty:
             logger.debug("PeriodicWorkController: Hold queue is empty")
             return
+
         if datetime.now() >= eta:
             logger.debug(
                 "PeriodicWorkController: Time to run %s[%s] (%s)..." % (
                     task.task_name, task.task_id, eta))
+            on_accept() # Run the accept task callback.
             self.bucket_queue.put(task)
         else:
             logger.debug(
                 "PeriodicWorkController: ETA not ready for %s[%s] (%s)..." % (
                     task.task_name, task.task_id, eta))
-            self.hold_queue.put((task, eta))
+            self.hold_queue.put((task, eta, on_accept))