Browse Source

Pidbox improvements

Ask Solem 14 years ago
parent
commit
22e53f9f9c
6 changed files with 137 additions and 148 deletions
  1. 4 2
      celery/app/amqp.py
  2. 109 95
      celery/pidbox.py
  3. 1 1
      celery/task/base.py
  4. 4 31
      celery/task/control.py
  5. 5 8
      celery/worker/control/__init__.py
  6. 14 11
      celery/worker/listener.py

+ 4 - 2
celery/app/amqp.py

@@ -143,12 +143,14 @@ class AMQP(object):
     BrokerConnection = BrokerConnection
     Publisher = messaging.Publisher
     Consumer = messaging.Consumer
-    ConsumerSet = messaging.ConsumerSet
     _queues = None
 
     def __init__(self, app):
         self.app = app
 
+    def ConsumerSet(self, *args, **kwargs):
+        return messaging.ConsumerSet(*args, **kwargs)
+
     def Queues(self, queues):
         return Queues.with_defaults(queues,
                                     self.app.conf.CELERY_DEFAULT_EXCHANGE,
@@ -203,7 +205,7 @@ class AMQP(object):
         transport = transport or "amqp"
 
         port = broker_connection.port or \
-                    broker_connection.get_backend_cls().default_port
+                    broker_connection.get_transport_cls().default_port
         port = port and ":%s" % port or ""
 
         vhost = broker_connection.virtual_host

+ 109 - 95
celery/pidbox.py

@@ -3,28 +3,56 @@ import warnings
 
 from itertools import count
 
-from carrot.messaging import Consumer, Publisher
+from kombu.entity import Exchange, Queue
+from kombu.messaging import Consumer, Producer
 
 from celery.app import app_or_default
-
-
-class ControlReplyConsumer(Consumer):
-    exchange = "celerycrq"
-    exchange_type = "direct"
-    durable = False
-    exclusive = False
-    auto_delete = True
-    no_ack = True
-
-    def __init__(self, connection, ticket, **kwargs):
-        self.ticket = ticket
-        queue = "%s.%s" % (self.exchange, ticket)
-        super(ControlReplyConsumer, self).__init__(connection,
-                                                   queue=queue,
-                                                   routing_key=ticket,
-                                                   **kwargs)
-
-    def collect(self, limit=None, timeout=1, callback=None):
+from celery.utils import gen_unique_id
+
+
+
+class Mailbox(object):
+
+    def __init__(self, namespace, connection):
+        self.namespace = namespace
+        self.connection = connection
+        self.exchange = Exchange("%s.pidbox" % (self.namespace, ),
+                                 type="fanout",
+                                 durable=False,
+                                 auto_delete=True)
+        self.reply_exchange = Exchange("reply.%s.pidbox" % (self.namespace, ),
+                                 type="direct",
+                                 durable=False,
+                                 auto_delete=True)
+
+    def publish_reply(self, reply, exchange, routing_key, channel=None):
+        chan = channel or self.connection.channel()
+        try:
+            exchange = Exchange(exchange, exchange_type="direct",
+                                          delivery_mode="transient",
+                                          durable=False,
+                                          auto_delete=True)
+            producer = Producer(chan, exchange=exchange)
+            producer.publish(reply, routing_key=routing_key)
+        finally:
+            channel or chan.close()
+
+    def get_reply_queue(self, ticket):
+        return Queue("%s.%s" % (ticket, self.reply_exchange.name),
+                     exchange=self.reply_exchange,
+                     routing_key=ticket,
+                     durable=False,
+                     auto_delete=True)
+
+    def get_queue(self, hostname):
+        return Queue("%s.%s.pidbox" % (hostname, self.namespace),
+                     exchange=self.exchange)
+
+    def collect_reply(self, ticket, limit=None, timeout=1,
+            callback=None, channel=None):
+        chan = channel or self.connection.channel()
+        queue = self.get_reply_queue(ticket)
+        consumer = Consumer(channel, [queue], no_ack=True)
         responses = []
 
         def on_message(message_data, message):
@@ -32,82 +60,68 @@ class ControlReplyConsumer(Consumer):
                 callback(message_data)
             responses.append(message_data)
 
-        self.callbacks = [on_message]
-        self.consume()
-        for i in limit and range(limit) or count():
-            try:
-                self.connection.drain_events(timeout=timeout)
-            except socket.timeout:
-                break
-
-        return responses
-
-
-class ControlReplyPublisher(Publisher):
-    exchange = "celerycrq"
-    exchange_type = "direct"
-    delivery_mode = "non-persistent"
-    durable = False
-    auto_delete = True
-
-
-class BroadcastPublisher(Publisher):
-    """Publish broadcast commands"""
-
-    ReplyTo = ControlReplyConsumer
-
-    def __init__(self, *args, **kwargs):
-        app = self.app = app_or_default(kwargs.get("app"))
-        kwargs["exchange"] = kwargs.get("exchange") or \
-                                app.conf.CELERY_BROADCAST_EXCHANGE
-        kwargs["exchange_type"] = kwargs.get("exchange_type") or \
-                                app.conf.CELERY_BROADCAST_EXCHANGE_TYPE
-        super(BroadcastPublisher, self).__init__(*args, **kwargs)
-
-    def send(self, type, arguments, destination=None, reply_ticket=None):
-        """Send broadcast command."""
+        try:
+            consumer.register_callback(on_message)
+            consumer.consume()
+            for i in limit and range(limit) or count():
+                try:
+                    self.connection.drain_events(timeout=timeout)
+                except socket.timeout:
+                    break
+            return responses
+        finally:
+            channel or chan.close()
+
+    def publish(self, type, arguments, destination=None, reply_ticket=None,
+            channel=None):
         arguments["command"] = type
         arguments["destination"] = destination
-        reply_to = self.ReplyTo(self.connection, None, app=self.app,
-                                auto_declare=False)
         if reply_ticket:
-            arguments["reply_to"] = {"exchange": reply_to.exchange,
+            arguments["reply_to"] = {"exchange": self.reply_exchange.name,
                                      "routing_key": reply_ticket}
-        super(BroadcastPublisher, self).send({"control": arguments})
-
-
-class BroadcastConsumer(Consumer):
-    """Consume broadcast commands"""
-    no_ack = True
-
-    def __init__(self, *args, **kwargs):
-        self.app = app = app_or_default(kwargs.get("app"))
-        kwargs["queue"] = kwargs.get("queue") or \
-                            app.conf.CELERY_BROADCAST_QUEUE
-        kwargs["exchange"] = kwargs.get("exchange") or \
-                            app.conf.CELERY_BROADCAST_EXCHANGE
-        kwargs["exchange_type"] = kwargs.get("exchange_type") or \
-                            app.conf.CELERY_BROADCAST_EXCHANGE_TYPE
-        self.hostname = kwargs.pop("hostname", None) or socket.gethostname()
-        self.queue = "%s_%s" % (self.queue, self.hostname)
-        super(BroadcastConsumer, self).__init__(*args, **kwargs)
-
-    def verify_exclusive(self):
-        # XXX Kombu material
-        channel = getattr(self.backend, "channel")
-        if channel and hasattr(channel, "queue_declare"):
-            try:
-                _, _, consumers = channel.queue_declare(self.queue,
-                                                        passive=True)
-            except ValueError:
-                pass
-            else:
-                if consumers:
-                    warnings.warn(UserWarning(
-                        "A node named %s is already using this process "
-                        "mailbox. Maybe you should specify a custom name "
-                        "for this node with the -n argument?" % self.hostname))
-
-    def consume(self, *args, **kwargs):
-        self.verify_exclusive()
-        return super(BroadcastConsumer, self).consume(*args, **kwargs)
+        chan = channel or self.connection.channel()
+        producer = Producer(exchange=self.exchange, delivery_mode="transient")
+        try:
+            producer.publish({"control": arguments})
+        finally:
+            channel or chan.close()
+
+    def get_consumer(self, hostname, channel=None):
+        return Consumer(channel or self.connection.channel(),
+                        [self.get_queue(hostname)],
+                        no_ack=True)
+
+    def broadcast(self, command, arguments=None, destination=None,
+            reply=False, timeout=1, limit=None, callback=None, channel=None):
+        arguments = arguments or {}
+        reply_ticket = reply and gen_unique_id() or None
+
+        if destination is not None and \
+                not isinstance(destination, (list, tuple)):
+            raise ValueError("destination must be a list/tuple not %s" % (
+                    type(destination)))
+
+        # Set reply limit to number of destinations (if specificed)
+        if limit is None and destination:
+            limit = destination and len(destination) or None
+
+        chan = channel or self.connection.channel()
+        try:
+            if reply_ticket:
+                self.get_reply_queue(reply_ticket)(chan).declare()
+
+            self.publish(command, arguments, destination=destination,
+                                             reply_ticket=reply_ticket,
+                                             channel=chan)
+
+            if reply_ticket:
+                return self.collect_reply(reply_ticket, limit=limit,
+                                                        timeout=timeout,
+                                                        callback=callback,
+                                                        channel=chan)
+        finally:
+            channel or chan.close()
+
+
+def mailbox(connection):
+    return Mailbox("celeryd", connection)

+ 1 - 1
celery/task/base.py

@@ -541,7 +541,7 @@ class BaseTask(object):
         if kwargs is None:
             kwargs = request.kwargs
 
-        delivery_info = request.delivery_info
+        delivery_info = request.delivery_info or {}
         options.setdefault("exchange", delivery_info.get("exchange"))
         options.setdefault("routing_key", delivery_info.get("routing_key"))
 

+ 4 - 31
celery/task/control.py

@@ -1,5 +1,5 @@
 from celery.app import app_or_default
-from celery.pidbox import BroadcastPublisher, ControlReplyConsumer
+from celery.pidbox import mailbox
 from celery.utils import gen_unique_id
 
 
@@ -184,37 +184,10 @@ class Control(object):
             received.
 
         """
-        arguments = arguments or {}
-        reply_ticket = reply and gen_unique_id() or None
-
-        if destination is not None and \
-                not isinstance(destination, (list, tuple)):
-            raise ValueError("destination must be a list/tuple not %s" % (
-                    type(destination)))
-
-        # Set reply limit to number of destinations (if specificed)
-        if limit is None and destination:
-            limit = destination and len(destination) or None
-
         def _do_broadcast(connection=None, connect_timeout=None):
-
-            crq = None
-            if reply_ticket:
-                crq = ControlReplyConsumer(connection, reply_ticket)
-
-            broadcaster = BroadcastPublisher(connection, app=self.app)
-            try:
-                broadcaster.send(command, arguments, destination=destination,
-                                 reply_ticket=reply_ticket)
-            finally:
-                broadcaster.close()
-
-            if crq:
-                try:
-                    return crq.collect(limit=limit, timeout=timeout,
-                                       callback=callback)
-                finally:
-                    crq.close()
+            return mailbox(connection).broadcast(command, arguments,
+                                                 destination, reply,
+                                                 timeout, limit, callback)
 
         return self.app.with_default_connection(_do_broadcast)(
                 connection=connection, connect_timeout=connect_timeout)

+ 5 - 8
celery/worker/control/__init__.py

@@ -1,5 +1,7 @@
+import socket
+
 from celery.app import app_or_default
-from celery.pidbox import ControlReplyPublisher
+from celery.pidbox import mailbox
 from celery.utils import kwdict
 from celery.worker.control.registry import Panel
 
@@ -9,12 +11,11 @@ __import__("celery.worker.control.builtins")
 class ControlDispatch(object):
     """Execute worker control panel commands."""
     Panel = Panel
-    ReplyPublisher = ControlReplyPublisher
 
     def __init__(self, logger=None, hostname=None, listener=None, app=None):
         self.app = app_or_default(app)
         self.logger = logger or self.app.log.get_default_logger()
-        self.hostname = hostname
+        self.hostname = hostname or socket.gethostname()
         self.listener = listener
         self.panel = self.Panel(self.logger, self.listener, self.hostname,
                                 app=self.app)
@@ -22,11 +23,7 @@ class ControlDispatch(object):
     def reply(self, data, exchange, routing_key, **kwargs):
 
         def _do_reply(connection=None, connect_timeout=None):
-            crq = self.ReplyPublisher(connection, exchange=exchange)
-            try:
-                crq.send(data, routing_key=routing_key)
-            finally:
-                crq.close()
+            mailbox(connection).publish_reply(data, exchange, routing_key)
 
         self.app.with_default_connection(_do_reply)(**kwargs)
 

+ 14 - 11
celery/worker/listener.py

@@ -77,7 +77,7 @@ from celery.app import app_or_default
 from celery.datastructures import SharedCounter
 from celery.events import EventDispatcher
 from celery.exceptions import NotRegistered
-from celery.pidbox import BroadcastConsumer
+from celery.pidbox import mailbox
 from celery.utils import noop
 from celery.utils.timer2 import to_timestamp
 from celery.worker.job import TaskRequest, InvalidTaskError
@@ -239,6 +239,8 @@ class CarrotListener(object):
     def consume_messages(self):
         """Consume messages forever (or until an exception is raised)."""
         self.logger.debug("CarrotListener: Starting message consumer...")
+        self.task_consumer.consume()
+        self.broadcast_consumer.consume()
         wait_for_message = self._mainloop().next
         self.logger.debug("CarrotListener: Ready to accept tasks!")
 
@@ -358,6 +360,9 @@ class CarrotListener(object):
             self.event_dispatcher = \
                     self.maybe_conn_error(self.event_dispatcher.close)
 
+        if self.broadcast_consumer:
+            self.broadcast_consumer.channel.close()
+
         if close:
             self.close_connection()
 
@@ -386,20 +391,20 @@ class CarrotListener(object):
 
         self.connection = self._open_connection()
         self.logger.debug("CarrotListener: Connection Established.")
-        self.task_consumer = self.app.amqp.get_task_consumer(
-                                        connection=self.connection,
-                                        queues=self.queues)
+        self.task_consumer = self.app.amqp.get_task_consumer(self.connection,
+                                                          queues=self.queues)
         # QoS: Reset prefetch window.
         self.qos = QoS(self.task_consumer,
                        self.initial_prefetch_count, self.logger)
         self.qos.update()                   # enable prefetch_count
 
         self.task_consumer.on_decode_error = self.on_decode_error
-        self.broadcast_consumer = BroadcastConsumer(self.connection,
-                                                    app=self.app,
-                                                    hostname=self.hostname)
         self.task_consumer.register_callback(self.receive_message)
 
+        self.broadcast_consumer = mailbox(self.connection).get_consumer(
+                                        self.hostname)
+        self.broadcast_consumer.register_callback(self.receive_message)
+
         # Flush events sent while connection was down.
         if self.event_dispatcher:
             self.event_dispatcher.flush()
@@ -416,9 +421,6 @@ class CarrotListener(object):
         self.heart.start()
 
     def _mainloop(self):
-        elf.broadcast_consumer.register_callback(self.receive_message)
-        self.task_consumer.consume()
-        self.broadcast_consumer.consume()
         while 1:
             yield self.connection.drain_events()
 
@@ -433,7 +435,8 @@ class CarrotListener(object):
 
         conn = self.app.broker_connection()
         if not self.app.conf.BROKER_CONNECTION_RETRY:
-            return conn.connect()
+            conn.connect()
+            return conn
 
         return conn.ensure_connection(_connection_error_handler,
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES)