Ver Fonte

Doesn't work properly yet.

Ask Solem há 16 anos atrás
pai
commit
cb7ff4e29e
3 ficheiros alterados com 93 adições e 73 exclusões
  1. 4 1
      celery/execute.py
  2. 0 2
      celery/messaging.py
  3. 89 70
      celery/worker.py

+ 4 - 1
celery/execute.py

@@ -3,6 +3,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.result import AsyncResult
 from celery.messaging import TaskPublisher
 from functools import partial as curry
+from datetime import datetime, timedelta
 
 
 def apply_async(task, args=None, kwargs=None, routing_key=None,
@@ -50,6 +51,8 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
     priority = priority or getattr(task, "priority", None)
     taskset_id = opts.get("taskset_id")
     publisher = opts.get("publisher")
+    if countdown:
+        eta = datetime.now() + timedelta(seconds=countdown)
 
     need_to_close_connection = False
     if not publisher:
@@ -65,7 +68,7 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
     task_id = delay_task(task.name, args, kwargs,
                          routing_key=routing_key, mandatory=mandatory,
                          immediate=immediate, priority=priority,
-                         countdown=countdown, eta=eta)
+                         eta=eta)
 
     if need_to_close_connection:
         publisher.close()

+ 0 - 2
celery/messaging.py

@@ -43,7 +43,6 @@ class TaskPublisher(Publisher):
             task_args=None, task_kwargs=None, **kwargs):
         """INTERNAL"""
         eta = kwargs.get("eta")
-        countdown = kwargs.get("countdown")
         priority = kwargs.get("priority")
         immediate = kwargs.get("immediate")
         mandatory = kwargs.get("mandatory")
@@ -57,7 +56,6 @@ class TaskPublisher(Publisher):
             "task": task_name,
             "args": task_args,
             "kwargs": task_kwargs,
-            "countdown": countdown,
             "eta": eta,
         }
         if part_of_set:

+ 89 - 70
celery/worker.py

@@ -265,6 +265,75 @@ class TaskWrapper(object):
                 meta={"task_id": self.task_id, "task_name": self.task_name})
 
 
+class AMQPMediator(threading.Thread):
+    """Thread continously taking care of new messages pushed by the
+    AMQP broker."""
+    
+    def __init__(self, bucket_queue, hold_queue):
+        super(AMQPMediator, self).__init__()
+        self._shutdown = threading.Event()
+        self._stopped = threading.Event()
+        self.bucket_queue = bucket_queue
+        self.hold_queue = hold_queue
+        self.amqp_connection = None
+        self.task_consumer = None
+
+    def add_to_queue(self, message_data, message):
+        eta = message_data.get("eta")
+        if eta:
+            print("ADD TO HOLD QUEUE: %s" % eta)
+            self.hold_queue.put((message_data, message, eta))
+        else:
+            self.bucket_queue.put((message_data, message))
+    
+    def run(self):
+        print("THREAD RUNNING")
+        task_consumer = self.reset_connection()
+        it = task_consumer.iterconsume(limit=None)
+        while True:
+            if self._shutdown.isSet():
+                break
+            print("TRYING TO GET NEXT MESSAGE")
+            it.next()
+        self.close_connection()
+        self._stopped.set() # indicate that we are stopped
+
+    def stop(self):
+        """Shutdown the thread."""
+        self._shutdown.set()
+        self._stopped.wait() # block until this thread is done
+
+    def close_connection(self):
+        """Close the AMQP connection."""
+        if self.task_consumer:
+            self.task_consumer.close()
+        if self.amqp_connection:
+            self.amqp_connection.close()
+
+    def reset_connection(self):
+        """Reset the AMQP connection, and reinitialize the
+        :class:`celery.messaging.TaskConsumer` instance.
+
+        Resets the task consumer in :attr:`task_consumer`.
+
+        """
+        self.close_connection()
+        self.amqp_connection = DjangoAMQPConnection()
+        self.task_consumer = TaskConsumer(connection=self.amqp_connection)
+        self.task_consumer.register_callback(self.add_to_queue)
+        return self.task_consumer
+
+    def connection_diagnostics(self):
+        """Diagnose the AMQP connection, and reset connection if
+        necessary."""
+        connection = self.task_consumer.backend.channel.connection
+
+        if not connection:
+            self.logger.info(
+                    "AMQP Connection has died, restoring connection.")
+            self.reset_connection()
+
+
 class PeriodicWorkController(threading.Thread):
     """A thread that continuously checks if there are
     :class:`celery.task.PeriodicTask` tasks waiting for execution,
@@ -347,41 +416,11 @@ class WorkController(object):
         self.pool = TaskPool(self.concurrency, logger=self.logger)
         self.periodicworkcontroller = PeriodicWorkController()
         self.is_detached = is_detached
-        self.amqp_connection = None
-        self.task_consumer = None
-        self.bucket_queue = Queue()
+        self.bucket_queue = Queue(maxsize=self.concurrency)
+        self.hold_queue = Queue()
+        self.amqp_mediator = AMQPMediator(self.bucket_queue, self.hold_queue)
 
-    def close_connection(self):
-        """Close the AMQP connection."""
-        if self.task_consumer:
-            self.task_consumer.close()
-        if self.amqp_connection:
-            self.amqp_connection.close()
-
-    def reset_connection(self):
-        """Reset the AMQP connection, and reinitialize the
-        :class:`celery.messaging.TaskConsumer` instance.
-
-        Resets the task consumer in :attr:`task_consumer`.
-
-        """
-        self.close_connection()
-        self.amqp_connection = DjangoAMQPConnection()
-        self.task_consumer = TaskConsumer(connection=self.amqp_connection)
-        self.task_consumer.register_callback(self._message_callback)
-        return self.task_consumer
-
-    def connection_diagnostics(self):
-        """Diagnose the AMQP connection, and reset connection if
-        necessary."""
-        connection = self.task_consumer.backend.channel.connection
-
-        if not connection:
-            self.logger.info(
-                    "AMQP Connection has died, restoring connection.")
-            self.reset_connection()
-
-    def _message_callback(self, message_data, message):
+    def safe_process_task(self, message_data, message):
         """The method called when we receive a message."""
         try:
             try:
@@ -401,14 +440,6 @@ class WorkController(object):
     def process_task(self, message_data, message):
         """Process task message by passing it to the pool of workers."""
         
-        countdown = message_data.get("countdown")
-        eta = message_data.get("eta")
-        if countdown:
-            eta = datetime.now() + timedelta(seconds=int(countdown))
-        if eta:
-            self.bucket_queue.put((message, message_data, eta))
-            return
-
         task = TaskWrapper.from_message(message, message_data,
                                         logger=self.logger)
         self.logger.info("Got task from broker: %s[%s]" % (
@@ -428,6 +459,7 @@ class WorkController(object):
         if self._state != "RUN":
             return
         self._state = "TERMINATE"
+        self.amqp_mediator.stop()
         self.periodicworkcontroller.stop()
         self.pool.terminate()
         self.close_connection()
@@ -435,47 +467,34 @@ class WorkController(object):
     def run(self):
         """Starts the workers main loop."""
         self._state = "RUN"
-        task_consumer = self.reset_connection()
-        it = task_consumer.iterconsume(limit=None)
 
         self.pool.run()
+        self.amqp_mediator.start()
         self.periodicworkcontroller.start()
 
-        # If not running as daemon, and DEBUG logging level is enabled,
-        # print pool PIDs and sleep for a second before we start.
-        if self.logger.isEnabledFor(logging.DEBUG):
-            self.logger.debug("Pool child processes: [%s]" % (
-                "|".join(map(str, self.pool.get_worker_pids()))))
-            if not self.is_detached:
-                time.sleep(1)
-
         try:
             while True:
-                try:
-                    self.process_bucket()
-                    self.process_next(it, timeout=1)
-                except TimeoutError:
-                    pass
+                self.process_hold()
+                self.process_bucket()
         except (SystemExit, KeyboardInterrupt):
             self.shutdown()
 
-    def process_next(self, it, timeout=1):
-        def on_timeout():
-            raise TimeoutError()
-        timer = threading.Timer(timeout, on_timeout)
-        timer.start()
+    def process_hold(self):
         try:
-            it.next()
-        finally:
-            timer.cancel()
+            message_data, message, eta = self.hold_queue.get(timeout=0.2)
+        except QueueEmpty:
+            pass
+        else:
+            print("GOT ITEM FROM HOLD QUEUE: %s" % eta)
+            if datetime.now() >= eta:
+                self.safe_process_task(message_data, message)
+            else:
+                self.hold_queue.put((message_data, message, eta))
 
     def process_bucket(self):
         try:
-            message, msg_data, eta = self.bucket_queue.get_nowait()
+            message_data, message = self.bucket_queue.get(timeout=0.2)
         except QueueEmpty:
             pass
         else:
-            if datetime.now() >= eta:
-                self.process_task(message, msg_data)
-            else:
-                self.bucket_queue.put((message, msg_data, eta))
+            self.safe_process_task(message_data, message)