Explorar o código

Consumer: cosmetics and use moar mock.Mock

Ask Solem %!s(int64=14) %!d(string=hai) anos
pai
achega
e847ec1d17
Modificáronse 2 ficheiros con 190 adicións e 88 borrados
  1. 16 12
      celery/tests/test_worker/test_worker.py
  2. 174 76
      celery/worker/consumer.py

+ 16 - 12
celery/tests/test_worker/test_worker.py

@@ -1,6 +1,7 @@
 import socket
 import sys
 
+from collections import deque
 from datetime import datetime, timedelta
 from Queue import Empty
 
@@ -212,7 +213,7 @@ class test_Consumer(unittest.TestCase):
 
         l._state = RUN
         l.event_dispatcher = None
-        l.stop_consumers(close=False)
+        l.stop_consumers(close_connection=False)
         self.assertTrue(l.connection)
 
         l._state = RUN
@@ -237,11 +238,12 @@ class test_Consumer(unittest.TestCase):
 
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
-        eventer = l.event_dispatcher = MockEventDispatcher()
+        eventer = l.event_dispatcher = Mock()
+        eventer.enabled = True
         heart = l.heart = MockHeart()
         l._state = RUN
         l.stop_consumers()
-        self.assertTrue(eventer.closed)
+        self.assertTrue(eventer.close.call_count)
         self.assertTrue(heart.closed)
 
     def test_receive_message_unknown(self):
@@ -249,7 +251,7 @@ class test_Consumer(unittest.TestCase):
                            send_events=False)
         backend = Mock()
         m = create_message(backend, unknown={"baz": "!!!"})
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
         def with_catch_warnings(log):
@@ -274,7 +276,7 @@ class test_Consumer(unittest.TestCase):
                                     args=("2, 2"),
                                     kwargs={},
                                     eta=datetime.now().isoformat())
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
         prev, timer2.to_timestamp = timer2.to_timestamp, to_timestamp
@@ -292,11 +294,12 @@ class test_Consumer(unittest.TestCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name,
             args=(1, 2), kwargs="foobarbaz", id=1)
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
         l.receive_message(m.decode(), m)
-        self.assertIn("Invalid task ignored", logger.error.call_args[0][0])
+        self.assertIn("Received invalid task message",
+                      logger.error.call_args[0][0])
 
     def test_on_decode_error(self):
         logger = Mock()
@@ -325,7 +328,7 @@ class test_Consumer(unittest.TestCase):
         m = create_message(backend, task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
 
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
         l.receive_message(m.decode(), m)
 
         in_bucket = self.ready_queue.get_nowait()
@@ -436,7 +439,7 @@ class test_Consumer(unittest.TestCase):
 
         l.task_consumer = MockConsumer()
         l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
         l.receive_message(m.decode(), m)
         l.eta_schedule.stop()
 
@@ -469,7 +472,7 @@ class test_Consumer(unittest.TestCase):
         backend = Mock()
         m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
 
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
         self.assertFalse(l.receive_message(m.decode(), m))
         self.assertRaises(Empty, self.ready_queue.get_nowait)
         self.assertTrue(self.eta_schedule.empty())
@@ -477,7 +480,8 @@ class test_Consumer(unittest.TestCase):
     def test_receieve_message_eta(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
+        l.event_dispatcher._outbound_buffer = deque()
         backend = Mock()
         m = create_message(backend, task=foo_task.name,
                            args=[2, 4, 8], kwargs={},
@@ -492,7 +496,7 @@ class test_Consumer(unittest.TestCase):
         finally:
             l.app.conf.BROKER_CONNECTION_RETRY = p
         l.stop_consumers()
-        l.event_dispatcher = MockEventDispatcher()
+        l.event_dispatcher = Mock()
         l.receive_message(m.decode(), m)
         l.eta_schedule.stop()
         in_hold = self.eta_schedule.queue[0]

+ 174 - 76
celery/worker/consumer.py

@@ -93,6 +93,31 @@ CLOSE = 0x2
 #: Prefetch count can't exceed short.
 PREFETCH_COUNT_MAX = 0xFFFF
 
+#: Error message for when an unregistered task is received.
+UNKNOWN_TASK_ERROR = """\
+Received unregistered task of type %s.
+The message has been ignored and discarded.
+
+Did you remember to import the module containing this task?
+Or maybe you are using relative imports?
+Please see http://bit.ly/gLye1c for more information.
+
+The full contents of the message body was:
+%s
+"""
+
+#: Error message for when an invalid task message is received.
+INVALID_TASK_ERROR = """\
+Received invalid task message: %s
+The message has been ignored and discarded.
+
+Please ensure your message conforms to the task
+message protocol as described here: http://bit.ly/hYj41y
+
+The full contents of the message body was:
+%s
+"""
+
 
 class QoS(object):
     """Quality of Service for Channel.
@@ -182,50 +207,59 @@ class Consumer(object):
     :param ready_queue: See :attr:`ready_queue`.
     :param eta_schedule: See :attr:`eta_schedule`.
 
-    .. attribute:: ready_queue
-
-        The queue that holds tasks ready for immediate processing.
-
-    .. attribute:: eta_schedule
-
-        Scheduler for paused tasks. Reasons for being paused include
-        a countdown/eta or that it's waiting for retry.
-
-    .. attribute:: send_events
+    """
 
-        Is events enabled?
+    #: The queue that holds tasks ready for immediate processing.
+    ready_queue = None
 
-    .. attribute:: init_callback
+    #: Timer for tasks with an ETA/countdown.
+    eta_schedule = None
 
-        Callback to be called the first time the connection is active.
+    #: Enable/disable events.
+    send_events = False
 
-    .. attribute:: hostname
+    #: Optional callback to be called when the connection is established.
+    #: Will only be called once, even if the connection is lost and
+    #: re-established.
+    init_callback = None
 
-        Current hostname. Defaults to the system hostname.
+    #: The current hostname.  Defaults to the system hostname.
+    hostname = None
 
-    .. attribute:: initial_prefetch_count
+    #: Initial QoS prefetch count for the task channel.
+    initial_prefetch_count = 0
 
-        Initial QoS prefetch count for the task channel.
+    #: A :class:`celery.events.EventDispatcher` for sending events.
+    event_dispatcher = None
 
-    .. attribute:: control_dispatch
+    #: The thread that sends event heartbeats at regular intervals.
+    #: The heartbeats are used by monitors to detect that a worker
+    #: went offline/disappeared.
+    heart = None
 
-        Control command dispatcher.
-        See :class:`celery.worker.control.ControlDispatch`.
+    #: The logger instance to use.  Defaults to the default Celery logger.
+    logger = None
 
-    .. attribute:: event_dispatcher
+    #: The broker connection.
+    connection = None
 
-        See :class:`celery.events.EventDispatcher`.
+    #: The consumer used to consume task messages.
+    task_consumer = None
 
-    .. attribute:: hart
+    #: The consumer used to consume broadcast commands.
+    broadcast_consumer = None
 
-        :class:`~celery.worker.heartbeat.Heart` sending out heart beats
-        if events enabled.
+    #: The process mailbox (kombu pidbox node).
+    pidbox_node = None
 
-    .. attribute:: logger
+    #: The current worker pool instance.
+    pool = None
 
-        The logger used.
+    #: A timer used for high-priority internal tasks, such 
+    #: as sending heartbeats.
+    priority_timer = None
 
-    """
+    # Consumer state, can be RUN or CLOSE.
     _state = None
 
     def __init__(self, ready_queue, eta_schedule, logger,
@@ -262,8 +296,9 @@ class Consumer(object):
     def start(self):
         """Start the consumer.
 
-        If the connection is lost, it tries to re-establish the connection
-        and restarts consuming messages.
+        Automatically surivives intermittent connection failure,
+        and will retry establishing the connection and restart
+        consuming messages.
 
         """
 
@@ -280,9 +315,9 @@ class Consumer(object):
 
     def consume_messages(self):
         """Consume messages forever (or until an exception is raised)."""
-        self.logger.debug("Consumer: Starting message consumer...")
+        self._debug("Starting message consumer...")
         self.task_consumer.consume()
-        self.logger.debug("Consumer: Ready to accept tasks!")
+        self._debug("Ready to accept tasks!")
 
         while self._state != CLOSE and self.connection:
             if self.qos.prev != self.qos.value:
@@ -308,11 +343,12 @@ class Consumer(object):
 
         self.logger.info("Got task from broker: %s" % (task.shortinfo(), ))
 
-        self.event_dispatcher.send("task-received", uuid=task.task_id,
-                name=task.task_name, args=safe_repr(task.args),
-                kwargs=safe_repr(task.kwargs), retries=task.retries,
-                eta=task.eta and task.eta.isoformat(),
-                expires=task.expires and task.expires.isoformat())
+        if self.event_dispatcher.enabled:
+            self.event_dispatcher.send("task-received", uuid=task.task_id,
+                    name=task.task_name, args=safe_repr(task.args),
+                    kwargs=safe_repr(task.kwargs), retries=task.retries,
+                    eta=task.eta and task.eta.isoformat(),
+                    expires=task.expires and task.expires.isoformat())
 
         if task.eta:
             try:
@@ -332,6 +368,7 @@ class Consumer(object):
             self.ready_queue.put(task)
 
     def on_control(self, body, message):
+        """Process remote control command message."""
         try:
             self.pidbox_node.handle_message(body, message)
         except KeyError, exc:
@@ -343,22 +380,30 @@ class Consumer(object):
             self.reset_pidbox_node()
 
     def apply_eta_task(self, task):
+        """Method called by the timer to apply a task with an
+        ETA/countdown."""
         state.task_reserved(task)
         self.ready_queue.put(task)
         self.qos.decrement_eventually()
 
     def receive_message(self, body, message):
-        """The callback called when a new message is received. """
+        """Handles incoming messages.
+
+        :param body: The message body.
+        :param message: The kombu message object.
+
+        """
 
         # Handle task
         if body.get("task"):
+            # need to guard against errors occuring while acking the message.
             def ack():
                 try:
                     message.ack()
                 except self.connection_errors + (AttributeError, ), exc:
                     self.logger.critical(
-                            "Couldn't ack %r: message:%r reason:%r" % (
-                                message.delivery_tag, body, exc))
+                        "Couldn't ack %r: body:%r reason:%r" % (
+                            message.delivery_tag, safe_str(body), exc))
 
             try:
                 task = TaskRequest.from_message(message, body, ack,
@@ -366,13 +411,14 @@ class Consumer(object):
                                                 logger=self.logger,
                                                 hostname=self.hostname,
                                                 eventer=self.event_dispatcher)
+
             except NotRegistered, exc:
-                self.logger.error("Unknown task ignored: %r Body->%r" % (
-                        exc, body), exc_info=sys.exc_info())
+                self.logger.error(UNKNOWN_TASK_ERROR % (
+                        exc, safe_str(body)), exc_info=sys.exc_info())
                 message.ack()
             except InvalidTaskError, exc:
-                self.logger.error("Invalid task ignored: %s: %s" % (
-                        str(exc), body), exc_info=sys.exc_info())
+                self.logger.error(INVALID_TASK_ERROR % (
+                        str(exc), safe_str(body)), exc_info=sys.exc_info())
                 message.ack()
             else:
                 self.on_task(task)
@@ -380,10 +426,13 @@ class Consumer(object):
 
         warnings.warn(RuntimeWarning(
             "Received and deleted unknown message. Wrong destination?!? \
-             the message was: %s" % body))
+             the full contents of the message body was: %s" % (
+                 safe_str(body), )))
         message.ack()
 
     def maybe_conn_error(self, fun):
+        """Applies function but ignores any connection or channel
+        errors raised."""
         try:
             fun()
         except (AttributeError, ) + \
@@ -392,59 +441,73 @@ class Consumer(object):
             pass
 
     def close_connection(self):
+        """Closes the current broker connection and all open channels."""
         if self.task_consumer:
-            self.logger.debug("Consumer: " "Closing consumer channel...")
+            self._debug("Closing consumer channel...")
             self.task_consumer = \
                     self.maybe_conn_error(self.task_consumer.close)
+
         if self.broadcast_consumer:
-            self.logger.debug("CarrotListener: Closing broadcast channel...")
+            self._debug("Closing broadcast channel...")
             self.broadcast_consumer = \
                 self.maybe_conn_error(self.broadcast_consumer.channel.close)
 
         if self.connection:
-            self.logger.debug("Consumer: " "Closing connection to broker...")
+            self._debug("Closing broker connection...")
             self.connection = self.maybe_conn_error(self.connection.close)
 
-    def stop_consumers(self, close=True):
-        """Stop consuming."""
+    def stop_consumers(self, close_connection=True):
+        """Stop consuming tasks and broadcast commands, also stops
+        the heartbeat thread and event dispatcher.
+
+        :keyword close_connection: Set to False to skip closing the broker
+                                    connection.
+
+        """
         if not self._state == RUN:
             return
 
         if self.heart:
+            # Stop the heartbeat thread if it's running.
             self.logger.debug("Heart: Going into cardiac arrest...")
             self.heart = self.heart.stop()
 
-        self.logger.debug("TaskConsumer: Cancelling consumers...")
+        self._debug("Cancelling task consumer...")
         if self.task_consumer:
             self.maybe_conn_error(self.task_consumer.cancel)
 
         if self.event_dispatcher:
-            self.logger.debug("EventDispatcher: Shutting down...")
+            self._debug("Shutting down event dispatcher...")
             self.event_dispatcher = \
                     self.maybe_conn_error(self.event_dispatcher.close)
 
-        self.logger.debug("BroadcastConsumer: Cancelling consumer...")
+        self._debug("Cancelling broadcast consumer...")
         if self.broadcast_consumer:
             self.maybe_conn_error(self.broadcast_consumer.cancel)
 
-        if close:
+        if close_connection:
             self.close_connection()
 
     def on_decode_error(self, message, exc):
-        """Callback called if the message had decoding errors.
+        """Callback called if an error occurs while decoding
+        a message received.
+
+        Simply logs the error and acknowledges the message so it
+        doesn't enter a loop.
 
         :param message: The message with errors.
         :param exc: The original exception instance.
 
         """
-        self.logger.critical("Can't decode message body: %r "
-                             "(type:%r encoding:%r raw:%r')" % (
-                                exc, message.content_type,
-                                message.content_encoding,
-                                safe_str(message.body)))
+        self.logger.critical(
+            "Can't decode message body: %r (type:%r encoding:%r raw:%r')" % (
+                    exc, message.content_type, message.content_encoding,
+                    safe_str(message.body)))
         message.ack()
 
     def reset_pidbox_node(self):
+        """Sets up the process mailbox."""
+        # close previously opened channel if any.
         if self.pidbox_node.channel:
             try:
                 self.pidbox_node.channel.close()
@@ -459,6 +522,8 @@ class Consumer(object):
         self.broadcast_consumer.consume()
 
     def _green_pidbox_node(self):
+        """Sets up the process mailbox when running in a greenlet
+        environment."""
         conn = self._open_connection()
         self.pidbox_node.channel = conn.channel()
         self.broadcast_consumer = self.pidbox_node.listen(
@@ -472,27 +537,31 @@ class Consumer(object):
             conn.close()
 
     def reset_connection(self):
-        """Re-establish connection and set up consumers."""
-        self.logger.debug(
-                "Consumer: Re-establishing connection to the broker...")
+        """Re-establish the broker connection and set up consumers,
+        heartbeat and the event dispatcher."""
+        self._debug("Re-establishing connection to the broker...")
         self.stop_consumers()
 
-        # Clear internal queues.
+        # Clear internal queues to get rid of old messages.
+        # They can't be acked anyway, as a delivery tag is specific
+        # to the current channel.
         self.ready_queue.clear()
         self.eta_schedule.clear()
 
+        # Re-establish the broker connection and setup the task consumer.
         self.connection = self._open_connection()
-        self.logger.debug("Consumer: Connection Established.")
+        self._debug("Connection established.")
         self.task_consumer = self.app.amqp.get_task_consumer(self.connection,
                                     on_decode_error=self.on_decode_error)
         # QoS: Reset prefetch window.
         self.qos = QoS(self.task_consumer,
                        self.initial_prefetch_count, self.logger)
-        self.qos.update()                   # enable prefetch_count
+        self.qos.update()
 
+        # receive_message handles incomsing messages.
         self.task_consumer.register_callback(self.receive_message)
 
-        # Pidbox
+        # Setup the process mailbox.
         self.reset_pidbox_node()
 
         # Flush events sent while connection was down.
@@ -504,25 +573,41 @@ class Consumer(object):
             self.event_dispatcher.copy_buffer(prev_event_dispatcher)
             self.event_dispatcher.flush()
 
+        # Restart heartbeat thread.
         self.restart_heartbeat()
 
+        # We're back!
         self._state = RUN
 
     def restart_heartbeat(self):
+        """Restart the heartbeat thread.
+
+        This thread sends heartbeat events at intervals so monitors
+        can tell if the worker is offline/missing.
+
+        """
         self.heart = Heart(self.priority_timer, self.event_dispatcher)
         self.heart.start()
 
     def _open_connection(self):
-        """Open connection.  May retry opening the connection if configuration
-        allows that."""
+        """Establish the broker connection.
+
+        Will retry establishing the connection if the
+        :setting:`BROKER_CONNECTION_RETRY` setting is enabled
+
+        """
 
         def _connection_error_handler(exc, interval):
-            """Callback handler for connection errors."""
+            # Callback called for each retry when the connection
+            # can't be established.
             self.logger.error("Consumer: Connection Error: %s. " % exc
-                     + "Trying again in %d seconds..." % interval)
+                            + "Trying again in %d seconds..." % interval)
 
+        # remember that the connection is lazy, it won't establish
+        # until it's needed.
         conn = self.app.broker_connection()
         if not self.app.conf.BROKER_CONNECTION_RETRY:
+            # retry disabled, just call connect directly.
             conn.connect()
             return conn
 
@@ -532,18 +617,31 @@ class Consumer(object):
     def stop(self):
         """Stop consuming.
 
-        Does not close connection.
+        Does not close the broker connection, so be sure to call
+        :meth:`close_connection` when you are finished with it.
 
         """
+        # Notifies other threads that this instance can't be used
+        # anymore.
         self._state = CLOSE
-        self.logger.debug("Consumer: Stopping consumers...")
-        self.stop_consumers(close=False)
+        self._debug("Stopping consumers...")
+        self.stop_consumers(close_connection=False)
 
     @property
     def info(self):
+        """Returns information about this consumer instance
+        as a dict.
+
+        This is also the consumer related info returned by
+        ``celeryctl stats``.
+
+        """
         conninfo = {}
         if self.connection:
             conninfo = self.connection.info()
             conninfo.pop("password", None)  # don't send password.
-        return {"broker": conninfo,
-                "prefetch_count": self.qos.value}
+        return {"broker": conninfo, "prefetch_count": self.qos.value}
+
+    def _debug(self, msg, **kwargs):
+        self.logger.debug("Consumer: %s" % (msg, ), **kwargs)
+