Prechádzať zdrojové kódy

Initial Consumer Bootsteps implementation

Ask Solem 12 rokov pred
rodič
commit
29941b2e4e

+ 7 - 0
celery/utils/text.py

@@ -13,6 +13,8 @@ from textwrap import fill
 
 
 from pprint import pformat
 from pprint import pformat
 
 
+from kombu.utils.encoding import safe_repr
+
 
 
 def dedent_initial(s, n=4):
 def dedent_initial(s, n=4):
     return s[n:] if s[:n] == ' ' * n else s
     return s[n:] if s[:n] == ' ' * n else s
@@ -79,3 +81,8 @@ def pretty(value, width=80, nl_width=80, **kw):
         return '\n{0}{1}'.format(' ' * 4, pformat(value, width=nl_width, **kw))
         return '\n{0}{1}'.format(' ' * 4, pformat(value, width=nl_width, **kw))
     else:
     else:
         return pformat(value, width=width, **kw)
         return pformat(value, width=width, **kw)
+
+
+def dump_body(m, body):
+    return '{0} ({1}b)'.format(truncate(safe_repr(body), 1024),
+                               len(m.body))

+ 1 - 1
celery/worker/__init__.py

@@ -144,7 +144,7 @@ class WorkController(configurated):
 
 
     def on_stopped(self):
     def on_stopped(self):
         self.timer.stop()
         self.timer.stop()
-        self.consumer.close_connection()
+        self.consumer.shutdown()
 
 
         if self.pidlock:
         if self.pidlock:
             self.pidlock.release()
             self.pidlock.release()

+ 22 - 9
celery/worker/bootsteps.py

@@ -86,11 +86,20 @@ class Namespace(object):
             else:
             else:
                 close(parent)
                 close(parent)
 
 
-    def stop(self, parent, terminate=False):
-        what = 'Terminating' if terminate else 'Stopping'
+    def restart(self, parent, description='Restarting', terminate=False):
         socket_timeout = socket.getdefaulttimeout()
         socket_timeout = socket.getdefaulttimeout()
         socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
         socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
-
+        try:
+            for component in reversed(parent.components):
+                if component:
+                    logger.debug('%s %s...', description, qualname(component))
+                    (component.terminate if terminate
+                        else component.stop)(parent)
+        finally:
+            socket.setdefaulttimeout(socket_timeout)
+
+    def stop(self, parent, close=True, terminate=False):
+        what = 'Terminating' if terminate else 'Stopping'
         if self.state in (CLOSE, TERMINATE):
         if self.state in (CLOSE, TERMINATE):
             return
             return
 
 
@@ -102,16 +111,11 @@ class Namespace(object):
             self.shutdown_complete.set()
             self.shutdown_complete.set()
             return
             return
         self.state = CLOSE
         self.state = CLOSE
-
-        for component in reversed(parent.components):
-            if component:
-                logger.debug('%s %s...', what, qualname(component))
-                (component.terminate if terminate else component.stop)(parent)
+        self.restart(parent, what, terminate)
 
 
         if self.on_stopped:
         if self.on_stopped:
             self.on_stopped()
             self.on_stopped()
         self.state = TERMINATE
         self.state = TERMINATE
-        socket.setdefaulttimeout(socket_timeout)
         self.shutdown_complete.set()
         self.shutdown_complete.set()
 
 
     def join(self, timeout=None):
     def join(self, timeout=None):
@@ -191,6 +195,13 @@ class Namespace(object):
                             *(self.name.capitalize(), ) + args)
                             *(self.name.capitalize(), ) + args)
 
 
 
 
+def _prepare_requires(req):
+    if not isinstance(req, basestring):
+        req = req.name
+    return req
+
+
+
 class ComponentType(type):
 class ComponentType(type):
     """Metaclass for components."""
     """Metaclass for components."""
 
 
@@ -204,6 +215,8 @@ class ComponentType(type):
             namespace = attrs.get('namespace', None)
             namespace = attrs.get('namespace', None)
             if not namespace:
             if not namespace:
                 attrs['namespace'], _, attrs['name'] = cname.partition('.')
                 attrs['namespace'], _, attrs['name'] = cname.partition('.')
+        attrs['requires'] = tuple(_prepare_requires(req)
+                                    for req in attrs.get('requires', ()))
         cls = super(ComponentType, cls).__new__(cls, name, bases, attrs)
         cls = super(ComponentType, cls).__new__(cls, name, bases, attrs)
         if not abstract:
         if not abstract:
             Namespace._unclaimed[cls.namespace][cls.name] = cls
             Namespace._unclaimed[cls.namespace][cls.name] = cls

+ 220 - 398
celery/worker/consumer.py

@@ -12,34 +12,53 @@ from __future__ import absolute_import
 
 
 import logging
 import logging
 import socket
 import socket
-import threading
 
 
 from time import sleep
 from time import sleep
 from Queue import Empty
 from Queue import Empty
 
 
-from kombu.common import QoS
 from kombu.syn import _detect_environment
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
 from kombu.utils.eventio import READ, WRITE, ERR
 from kombu.utils.eventio import READ, WRITE, ERR
 
 
 from celery.app import app_or_default
 from celery.app import app_or_default
-from celery.datastructures import AttributeDict
 from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.task.trace import build_tracer
 from celery.task.trace import build_tracer
-from celery.utils import text
-from celery.utils import timer2
+from celery.utils.timer2 import default_timer, to_timestamp
 from celery.utils.functional import noop
 from celery.utils.functional import noop
+from celery.utils.imports import qualname
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
+from celery.utils.text import dump_body
 from celery.utils.timeutils import humanize_seconds
 from celery.utils.timeutils import humanize_seconds
 
 
 from . import state
 from . import state
-from .bootsteps import StartStopComponent, RUN, CLOSE
-from .control import Panel
-from .heartbeat import Heart
+from .bootsteps import Namespace as _NS, StartStopComponent, CLOSE
+
+logger = get_logger(__name__)
+info, warn, error, crit = (logger.info, logger.warn,
+                           logger.error, logger.critical)
+task_reserved = state.task_reserved
 
 
 #: Heartbeat check is called every heartbeat_seconds' / rate'.
 #: Heartbeat check is called every heartbeat_seconds' / rate'.
 AMQHEARTBEAT_RATE = 2.0
 AMQHEARTBEAT_RATE = 2.0
 
 
+CONNECTION_RETRY = """\
+consumer: Connection to broker lost. \
+Trying to re-establish the connection...\
+"""
+
+CONNECTION_RETRY_STEP = """\
+Trying again {when}...\
+"""
+
+CONNECTION_ERROR = """\
+consumer: Cannot connect to %s: %s.
+%s
+"""
+
+CONNECTION_FAILOVER = """\
+Will retry using next failover.\
+"""
+
 UNKNOWN_FORMAT = """\
 UNKNOWN_FORMAT = """\
 Received and deleted unknown message. Wrong destination?!?
 Received and deleted unknown message. Wrong destination?!?
 
 
@@ -76,40 +95,10 @@ body: {0} {{content_type:{1} content_encoding:{2} delivery_info:{3}}}\
 """
 """
 
 
 
 
-RETRY_CONNECTION = """\
-consumer: Connection to broker lost. \
-Trying to re-establish the connection...\
-"""
-
-CONNECTION_ERROR = """\
-consumer: Cannot connect to %s: %s.
-%s
-"""
-
-CONNECTION_RETRY = """\
-Trying again {when}...\
-"""
-
-CONNECTION_FAILOVER = """\
-Will retry using next failover.\
-"""
-
-task_reserved = state.task_reserved
-
-logger = get_logger(__name__)
-info, warn, error, crit = (logger.info, logger.warn,
-                           logger.error, logger.critical)
-
-
 def debug(msg, *args, **kwargs):
 def debug(msg, *args, **kwargs):
     logger.debug('consumer: {0}'.format(msg), *args, **kwargs)
     logger.debug('consumer: {0}'.format(msg), *args, **kwargs)
 
 
 
 
-def dump_body(m, body):
-    return '{0} ({1}b)'.format(text.truncate(safe_repr(body), 1024),
-                               len(m.body))
-
-
 class Component(StartStopComponent):
 class Component(StartStopComponent):
     name = 'worker.consumer'
     name = 'worker.consumer'
     last = True
     last = True
@@ -134,6 +123,28 @@ class Component(StartStopComponent):
         return c
         return c
 
 
 
 
+class Namespace(_NS):
+    name = 'consumer'
+
+    def shutdown(self, parent):
+        delayed = self._shutdown_step(parent, parent.components, force=False)
+        self._shutdown_step(parent, delayed, force=True)
+
+    def _shutdown_step(self, parent, components, force=False):
+        delayed = []
+        for component in components:
+            if component:
+                logger.debug('Shutdown %s...', qualname(component))
+                if not force and getattr(component, 'delay_shutdown', False):
+                    delayed.append(component)
+                else:
+                    component.shutdown(parent)
+        return delayed
+
+    def modules(self):
+        return ('celery.worker.parts', )
+
+
 class Consumer(object):
 class Consumer(object):
     """Listen for messages received from the broker and
     """Listen for messages received from the broker and
     move them to the ready queue for task processing.
     move them to the ready queue for task processing.
@@ -146,9 +157,6 @@ class Consumer(object):
     #: The queue that holds tasks ready for immediate processing.
     #: The queue that holds tasks ready for immediate processing.
     ready_queue = None
     ready_queue = None
 
 
-    #: Enable/disable events.
-    send_events = False
-
     #: Optional callback to be called when the connection is established.
     #: Optional callback to be called when the connection is established.
     #: Will only be called once, even if the connection is lost and
     #: Will only be called once, even if the connection is lost and
     #: re-established.
     #: re-established.
@@ -157,31 +165,6 @@ class Consumer(object):
     #: The current hostname.  Defaults to the system hostname.
     #: The current hostname.  Defaults to the system hostname.
     hostname = None
     hostname = None
 
 
-    #: Initial QoS prefetch count for the task channel.
-    initial_prefetch_count = 0
-
-    #: A :class:`celery.events.EventDispatcher` for sending events.
-    event_dispatcher = None
-
-    #: The thread that sends event heartbeats at regular intervals.
-    #: The heartbeats are used by monitors to detect that a worker
-    #: went offline/disappeared.
-    heart = None
-
-    #: The broker connection.
-    connection = None
-
-    #: The consumer used to consume task messages.
-    task_consumer = None
-
-    #: The consumer used to consume broadcast commands.
-    broadcast_consumer = None
-
-    #: The process mailbox (kombu pidbox node).
-    pidbox_node = None
-    _pidbox_node_shutdown = None   # used for greenlets
-    _pidbox_node_stopped = None    # used for greenlets
-
     #: The current worker pool instance.
     #: The current worker pool instance.
     pool = None
     pool = None
 
 
@@ -189,41 +172,24 @@ class Consumer(object):
     #: as sending heartbeats.
     #: as sending heartbeats.
     timer = None
     timer = None
 
 
-    # Consumer state, can be RUN or CLOSE.
-    _state = None
-
     def __init__(self, ready_queue,
     def __init__(self, ready_queue,
-            init_callback=noop, send_events=False, hostname=None,
-            initial_prefetch_count=2, pool=None, app=None,
+            init_callback=noop, hostname=None,
+            pool=None, app=None,
             timer=None, controller=None, hub=None, amqheartbeat=None,
             timer=None, controller=None, hub=None, amqheartbeat=None,
             **kwargs):
             **kwargs):
         self.app = app_or_default(app)
         self.app = app_or_default(app)
-        self.connection = None
-        self.task_consumer = None
         self.controller = controller
         self.controller = controller
-        self.broadcast_consumer = None
         self.ready_queue = ready_queue
         self.ready_queue = ready_queue
-        self.send_events = send_events
         self.init_callback = init_callback
         self.init_callback = init_callback
         self.hostname = hostname or socket.gethostname()
         self.hostname = hostname or socket.gethostname()
-        self.initial_prefetch_count = initial_prefetch_count
-        self.event_dispatcher = None
-        self.heart = None
         self.pool = pool
         self.pool = pool
-        self.timer = timer or timer2.default_timer
-        pidbox_state = AttributeDict(app=self.app,
-                                     hostname=self.hostname,
-                                     listener=self,     # pre 2.2
-                                     consumer=self)
-        self.pidbox_node = self.app.control.mailbox.Node(self.hostname,
-                                                         state=pidbox_state,
-                                                         handlers=Panel.data)
+        self.timer = timer or default_timer
+        self.strategies = {}
         conninfo = self.app.connection()
         conninfo = self.app.connection()
         self.connection_errors = conninfo.connection_errors
         self.connection_errors = conninfo.connection_errors
         self.channel_errors = conninfo.channel_errors
         self.channel_errors = conninfo.channel_errors
 
 
         self._does_info = logger.isEnabledFor(logging.INFO)
         self._does_info = logger.isEnabledFor(logging.INFO)
-        self.strategies = {}
         if hub:
         if hub:
             hub.on_init.append(self.on_poll_init)
             hub.on_init.append(self.on_poll_init)
         self.hub = hub
         self.hub = hub
@@ -240,14 +206,16 @@ class Consumer(object):
             # connect again.
             # connect again.
             self.app.conf.BROKER_CONNECTION_TIMEOUT = None
             self.app.conf.BROKER_CONNECTION_TIMEOUT = None
 
 
-    def update_strategies(self):
-        S = self.strategies
-        app = self.app
-        loader = app.loader
-        hostname = self.hostname
-        for name, task in self.app.tasks.iteritems():
-            S[name] = task.start_strategy(app, self)
-            task.__trace__ = build_tracer(name, task, loader, hostname)
+        self.components = []
+        self.namespace = Namespace(app=self.app,
+                                   on_start=self.on_start,
+                                   on_close=self.on_close)
+        self.namespace.apply(self, **kwargs)
+
+    def on_start(self):
+        # reload all task's execution strategies.
+        self.update_strategies()
+        self.init_callback(self)
 
 
     def start(self):
     def start(self):
         """Start the consumer.
         """Start the consumer.
@@ -257,26 +225,144 @@ class Consumer(object):
         consuming messages.
         consuming messages.
 
 
         """
         """
-
-        self.init_callback(self)
-
-        while self._state != CLOSE:
+        ns = self.namespace
+        while ns.state != CLOSE:
             self.maybe_shutdown()
             self.maybe_shutdown()
             try:
             try:
-                self.reset_connection()
+                self.namespace.start(self)
                 self.consume_messages()
                 self.consume_messages()
             except self.connection_errors + self.channel_errors:
             except self.connection_errors + self.channel_errors:
-                error(RETRY_CONNECTION, exc_info=True)
+                error(CONNECTION_RETRY, exc_info=True)
+                ns.restart(self)
+            ns.close(self)
+            ns.state = CLOSE
 
 
     def on_poll_init(self, hub):
     def on_poll_init(self, hub):
         hub.update_readers(self.connection.eventmap)
         hub.update_readers(self.connection.eventmap)
         self.connection.transport.on_poll_init(hub.poller)
         self.connection.transport.on_poll_init(hub.poller)
 
 
+    def maybe_conn_error(self, fun):
+        """Applies function but ignores any connection or channel
+        errors raised."""
+        try:
+            fun()
+        except (AttributeError, ) + \
+                self.connection_errors + \
+                self.channel_errors:
+            pass
+
+    def shutdown(self):
+        self.namespace.shutdown(self)
+
+    def on_decode_error(self, message, exc):
+        """Callback called if an error occurs while decoding
+        a message received.
+
+        Simply logs the error and acknowledges the message so it
+        doesn't enter a loop.
+
+        :param message: The message with errors.
+        :param exc: The original exception instance.
+
+        """
+        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
+             exc, message.content_type, message.content_encoding,
+             dump_body(message, message.body))
+        message.ack()
+
+    def on_close(self):
+        # Clear internal queues to get rid of old messages.
+        # They can't be acked anyway, as a delivery tag is specific
+        # to the current channel.
+        self.ready_queue.clear()
+        self.timer.clear()
+
+    def _open_connection(self):
+        """Establish the broker connection.
+
+        Will retry establishing the connection if the
+        :setting:`BROKER_CONNECTION_RETRY` setting is enabled
+
+        """
+        conn = self.app.connection(heartbeat=self.amqheartbeat)
+
+        # Callback called for each retry while the connection
+        # can't be established.
+        def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
+            if getattr(conn, 'alt', None) and interval == 0:
+                next_step = CONNECTION_FAILOVER
+            error(CONNECTION_ERROR, conn.as_uri(), exc,
+                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))
+
+        # remember that the connection is lazy, it won't establish
+        # until it's needed.
+        if not self.app.conf.BROKER_CONNECTION_RETRY:
+            # retry disabled, just call connect directly.
+            conn.connect()
+            return conn
+
+        return conn.ensure_connection(_error_handler,
+                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
+                    callback=self.maybe_shutdown)
+
+    def stop(self):
+        """Stop consuming.
+
+        Does not close the broker connection, so be sure to call
+        :meth:`close_connection` when you are finished with it.
+
+        """
+        self.namespace.stop(self)
+
+    def maybe_shutdown(self):
+        if state.should_stop:
+            raise SystemExit()
+        elif state.should_terminate:
+            raise SystemTerminate()
+
+    def add_task_queue(self, queue, exchange=None, exchange_type=None,
+            routing_key=None, **options):
+        cset = self.task_consumer
+        try:
+            q = self.app.amqp.queues[queue]
+        except KeyError:
+            exchange = queue if exchange is None else exchange
+            exchange_type = 'direct' if exchange_type is None \
+                                     else exchange_type
+            q = self.app.amqp.queues.select_add(queue,
+                    exchange=exchange,
+                    exchange_type=exchange_type,
+                    routing_key=routing_key, **options)
+        if not cset.consuming_from(queue):
+            cset.add_queue(q)
+            cset.consume()
+            info('Started consuming from %r', queue)
+
+    def cancel_task_queue(self, queue):
+        self.app.amqp.queues.select_remove(queue)
+        self.task_consumer.cancel_by_queue(queue)
+
+    @property
+    def info(self):
+        """Returns information about this consumer instance
+        as a dict.
+
+        This is also the consumer related info returned by
+        ``celeryctl stats``.
+
+        """
+        conninfo = {}
+        if self.connection:
+            conninfo = self.connection.info()
+            conninfo.pop('password', None)  # don't send password.
+        return {'broker': conninfo, 'prefetch_count': self.qos.value}
+
     def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
     def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
             hbrate=AMQHEARTBEAT_RATE):
             hbrate=AMQHEARTBEAT_RATE):
         """Consume messages forever (or until an exception is raised)."""
         """Consume messages forever (or until an exception is raised)."""
 
 
         with self.hub as hub:
         with self.hub as hub:
+            ns = self.namespace
             qos = self.qos
             qos = self.qos
             update_qos = qos.update
             update_qos = qos.update
             update_readers = hub.update_readers
             update_readers = hub.update_readers
@@ -318,7 +404,7 @@ class Consumer(object):
 
 
             debug('Ready to accept tasks!')
             debug('Ready to accept tasks!')
 
 
-            while self._state != CLOSE and self.connection:
+            while ns.state != CLOSE and self.connection:
                 # shutdown if signal handlers told us to.
                 # shutdown if signal handlers told us to.
                 if state.should_stop:
                 if state.should_stop:
                     raise SystemExit()
                     raise SystemExit()
@@ -360,7 +446,7 @@ class Consumer(object):
                             except (KeyError, Empty):
                             except (KeyError, Empty):
                                 continue
                                 continue
                             except socket.error:
                             except socket.error:
-                                if self._state != CLOSE:  # pragma: no cover
+                                if ns.state != CLOSE:  # pragma: no cover
                                     raise
                                     raise
                         if keep_draining:
                         if keep_draining:
                             drain_nowait()
                             drain_nowait()
@@ -394,7 +480,7 @@ class Consumer(object):
 
 
         if task.eta:
         if task.eta:
             try:
             try:
-                eta = timer2.to_timestamp(task.eta)
+                eta = to_timestamp(task.eta)
             except OverflowError as exc:
             except OverflowError as exc:
                 error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                 error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                       task.eta, exc, task.info(safe=True), exc_info=True)
                       task.eta, exc, task.info(safe=True), exc_info=True)
@@ -407,16 +493,6 @@ class Consumer(object):
             task_reserved(task)
             task_reserved(task)
             self._quick_put(task)
             self._quick_put(task)
 
 
-    def on_control(self, body, message):
-        """Process remote control command message."""
-        try:
-            self.pidbox_node.handle_message(body, message)
-        except KeyError as exc:
-            error('No such control command: %s', exc)
-        except Exception as exc:
-            error('Control command error: %r', exc, exc_info=True)
-            self.reset_pidbox_node()
-
     def apply_eta_task(self, task):
     def apply_eta_task(self, task):
         """Method called by the timer to apply a task with an
         """Method called by the timer to apply a task with an
         ETA/countdown."""
         ETA/countdown."""
@@ -442,288 +518,14 @@ class Consumer(object):
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         message.reject_log_error(logger, self.connection_errors)
         message.reject_log_error(logger, self.connection_errors)
 
 
-    def receive_message(self, body, message):
-        """Handles incoming messages.
-
-        :param body: The message body.
-        :param message: The kombu message object.
-
-        """
-        try:
-            name = body['task']
-        except (KeyError, TypeError):
-            return self.handle_unknown_message(body, message)
-
-        try:
-            self.strategies[name](message, body, message.ack_log_error)
-        except KeyError as exc:
-            self.handle_unknown_task(body, message, exc)
-        except InvalidTaskError as exc:
-            self.handle_invalid_task(body, message, exc)
-
-    def maybe_conn_error(self, fun):
-        """Applies function but ignores any connection or channel
-        errors raised."""
-        try:
-            fun()
-        except (AttributeError, ) + \
-                self.connection_errors + \
-                self.channel_errors:
-            pass
-
-    def close_connection(self):
-        """Closes the current broker connection and all open channels."""
-
-        # We must set self.connection to None here, so
-        # that the green pidbox thread exits.
-        connection, self.connection = self.connection, None
-
-        if self.task_consumer:
-            debug('Closing consumer channel...')
-            self.task_consumer = \
-                    self.maybe_conn_error(self.task_consumer.close)
-
-        self.stop_pidbox_node()
-
-        if connection:
-            debug('Closing broker connection...')
-            self.maybe_conn_error(connection.close)
-
-    def stop_consumers(self, close_connection=True, join=True):
-        """Stop consuming tasks and broadcast commands, also stops
-        the heartbeat thread and event dispatcher.
-
-        :keyword close_connection: Set to False to skip closing the broker
-                                    connection.
-
-        """
-        if not self._state == RUN:
-            return
-
-        if self.heart:
-            # Stop the heartbeat thread if it's running.
-            debug('Heart: Going into cardiac arrest...')
-            self.heart = self.heart.stop()
-
-        debug('Cancelling task consumer...')
-        if join and self.task_consumer:
-            self.maybe_conn_error(self.task_consumer.cancel)
-
-        if self.event_dispatcher:
-            debug('Shutting down event dispatcher...')
-            self.event_dispatcher = \
-                    self.maybe_conn_error(self.event_dispatcher.close)
-
-        debug('Cancelling broadcast consumer...')
-        if join and self.broadcast_consumer:
-            self.maybe_conn_error(self.broadcast_consumer.cancel)
-
-        if close_connection:
-            self.close_connection()
-
-    def on_decode_error(self, message, exc):
-        """Callback called if an error occurs while decoding
-        a message received.
-
-        Simply logs the error and acknowledges the message so it
-        doesn't enter a loop.
-
-        :param message: The message with errors.
-        :param exc: The original exception instance.
-
-        """
-        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
-             exc, message.content_type, message.content_encoding,
-             dump_body(message, message.body))
-        message.ack()
-
-    def reset_pidbox_node(self):
-        """Sets up the process mailbox."""
-        self.stop_pidbox_node()
-        # close previously opened channel if any.
-        if self.pidbox_node.channel:
-            try:
-                self.pidbox_node.channel.close()
-            except self.connection_errors + self.channel_errors:
-                pass
-
-        if self.pool is not None and self.pool.is_green:
-            return self.pool.spawn_n(self._green_pidbox_node)
-        self.pidbox_node.channel = self.connection.channel()
-        self.broadcast_consumer = self.pidbox_node.listen(
-                                        callback=self.on_control)
-
-    def stop_pidbox_node(self):
-        if self._pidbox_node_stopped:
-            self._pidbox_node_shutdown.set()
-            debug('Waiting for broadcast thread to shutdown...')
-            self._pidbox_node_stopped.wait()
-            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
-        elif self.broadcast_consumer:
-            debug('Closing broadcast channel...')
-            self.broadcast_consumer = \
-                self.maybe_conn_error(self.broadcast_consumer.channel.close)
-
-    def _green_pidbox_node(self):
-        """Sets up the process mailbox when running in a greenlet
-        environment."""
-        # THIS CODE IS TERRIBLE
-        # Luckily work has already started rewriting the Consumer for 4.0.
-        self._pidbox_node_shutdown = threading.Event()
-        self._pidbox_node_stopped = threading.Event()
-        try:
-            with self._open_connection() as conn:
-                info('pidbox: Connected to %s.', conn.as_uri())
-                self.pidbox_node.channel = conn.default_channel
-                self.broadcast_consumer = self.pidbox_node.listen(
-                                            callback=self.on_control)
-                with self.broadcast_consumer:
-                    while not self._pidbox_node_shutdown.isSet():
-                        try:
-                            conn.drain_events(timeout=1.0)
-                        except socket.timeout:
-                            pass
-        finally:
-            self._pidbox_node_stopped.set()
-
-    def reset_connection(self):
-        """Re-establish the broker connection and set up consumers,
-        heartbeat and the event dispatcher."""
-        debug('Re-establishing connection to the broker...')
-        self.stop_consumers(join=False)
-
-        # Clear internal queues to get rid of old messages.
-        # They can't be acked anyway, as a delivery tag is specific
-        # to the current channel.
-        self.ready_queue.clear()
-        self.timer.clear()
-
-        # Re-establish the broker connection and setup the task consumer.
-        self.connection = self._open_connection()
-        info('consumer: Connected to %s.', self.connection.as_uri())
-        self.task_consumer = self.app.amqp.TaskConsumer(self.connection,
-                                    on_decode_error=self.on_decode_error)
-        # QoS: Reset prefetch window.
-        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
-        self.qos.update()
-
-        # Setup the process mailbox.
-        self.reset_pidbox_node()
-
-        # Flush events sent while connection was down.
-        prev_event_dispatcher = self.event_dispatcher
-        self.event_dispatcher = self.app.events.Dispatcher(self.connection,
-                                                hostname=self.hostname,
-                                                enabled=self.send_events)
-        if prev_event_dispatcher:
-            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
-            self.event_dispatcher.flush()
-
-        # Restart heartbeat thread.
-        self.restart_heartbeat()
-
-        # reload all task's execution strategies.
-        self.update_strategies()
-
-        # We're back!
-        self._state = RUN
-
-    def restart_heartbeat(self):
-        """Restart the heartbeat thread.
-
-        This thread sends heartbeat events at intervals so monitors
-        can tell if the worker is off-line/missing.
-
-        """
-        self.heart = Heart(self.timer, self.event_dispatcher)
-        self.heart.start()
-
-    def _open_connection(self):
-        """Establish the broker connection.
-
-        Will retry establishing the connection if the
-        :setting:`BROKER_CONNECTION_RETRY` setting is enabled
-
-        """
-        conn = self.app.connection(heartbeat=self.amqheartbeat)
-
-        # Callback called for each retry while the connection
-        # can't be established.
-        def _error_handler(exc, interval, next_step=CONNECTION_RETRY):
-            if getattr(conn, 'alt', None) and interval == 0:
-                next_step = CONNECTION_FAILOVER
-            error(CONNECTION_ERROR, conn.as_uri(), exc,
-                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))
-
-        # remember that the connection is lazy, it won't establish
-        # until it's needed.
-        if not self.app.conf.BROKER_CONNECTION_RETRY:
-            # retry disabled, just call connect directly.
-            conn.connect()
-            return conn
-
-        return conn.ensure_connection(_error_handler,
-                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
-                    callback=self.maybe_shutdown)
-
-    def stop(self):
-        """Stop consuming.
-
-        Does not close the broker connection, so be sure to call
-        :meth:`close_connection` when you are finished with it.
-
-        """
-        # Notifies other threads that this instance can't be used
-        # anymore.
-        self.close()
-        debug('Stopping consumers...')
-        self.stop_consumers(close_connection=False, join=True)
-
-    def close(self):
-        self._state = CLOSE
-
-    def maybe_shutdown(self):
-        if state.should_stop:
-            raise SystemExit()
-        elif state.should_terminate:
-            raise SystemTerminate()
-
-    def add_task_queue(self, queue, exchange=None, exchange_type=None,
-            routing_key=None, **options):
-        cset = self.task_consumer
-        try:
-            q = self.app.amqp.queues[queue]
-        except KeyError:
-            exchange = queue if exchange is None else exchange
-            exchange_type = 'direct' if exchange_type is None \
-                                     else exchange_type
-            q = self.app.amqp.queues.select_add(queue,
-                    exchange=exchange,
-                    exchange_type=exchange_type,
-                    routing_key=routing_key, **options)
-        if not cset.consuming_from(queue):
-            cset.add_queue(q)
-            cset.consume()
-            logger.info('Started consuming from %r', queue)
-
-    def cancel_task_queue(self, queue):
-        self.app.amqp.queues.select_remove(queue)
-        self.task_consumer.cancel_by_queue(queue)
-
-    @property
-    def info(self):
-        """Returns information about this consumer instance
-        as a dict.
-
-        This is also the consumer related info returned by
-        ``celeryctl stats``.
-
-        """
-        conninfo = {}
-        if self.connection:
-            conninfo = self.connection.info()
-            conninfo.pop('password', None)  # don't send password.
-        return {'broker': conninfo, 'prefetch_count': self.qos.value}
+    def update_strategies(self):
+        S = self.strategies
+        app = self.app
+        loader = app.loader
+        hostname = self.hostname
+        for name, task in self.app.tasks.iteritems():
+            S[name] = task.start_strategy(app, self)
+            task.__trace__ = build_tracer(name, task, loader, hostname)
 
 
 
 
 class BlockingConsumer(Consumer):
 class BlockingConsumer(Consumer):
@@ -734,8 +536,9 @@ class BlockingConsumer(Consumer):
         self.task_consumer.consume()
         self.task_consumer.consume()
 
 
         debug('Ready to accept tasks!')
         debug('Ready to accept tasks!')
+        ns = self.ns
 
 
-        while self._state != CLOSE and self.connection:
+        while ns.state != CLOSE and self.connection:
             self.maybe_shutdown()
             self.maybe_shutdown()
             if self.qos.prev != self.qos.value:     # pragma: no cover
             if self.qos.prev != self.qos.value:     # pragma: no cover
                 self.qos.update()
                 self.qos.update()
@@ -744,5 +547,24 @@ class BlockingConsumer(Consumer):
             except socket.timeout:
             except socket.timeout:
                 pass
                 pass
             except socket.error:
             except socket.error:
-                if self._state != CLOSE:            # pragma: no cover
+                if ns.state != CLOSE:            # pragma: no cover
                     raise
                     raise
+
+    def receive_message(self, body, message):
+        """Handles incoming messages.
+
+        :param body: The message body.
+        :param message: The kombu message object.
+
+        """
+        try:
+            name = body['task']
+        except (KeyError, TypeError):
+            return self.handle_unknown_message(body, message)
+
+        try:
+            self.strategies[name](message, body, message.ack_log_error)
+        except KeyError as exc:
+            self.handle_unknown_task(body, message, exc)
+        except InvalidTaskError as exc:
+            self.handle_invalid_task(body, message, exc)

+ 213 - 0
celery/worker/parts.py

@@ -0,0 +1,213 @@
+from __future__ import absolute_import
+
+import socket
+import threading
+
+from kombu.common import QoS
+
+from celery.datastructures import AttributeDict
+from celery.utils.log import get_logger
+
+from .bootsteps import StartStopComponent
+from .control import Panel
+from .heartbeat import Heart
+
+logger = get_logger(__name__)
+info, error, debug = logger.info, logger.error, logger.debug
+
+
+class ConsumerConnection(StartStopComponent):
+    name = 'consumer.connection'
+    delay_shutdown = True
+
+    def __init__(self, c, **kwargs):
+        c.connection = None
+
+    def start(self, c):
+        debug('  Re-establishing connection to the broker...')
+        c.connection = c._open_connection()
+        # Re-establish the broker connection and setup the task consumer.
+        info('  consumer: Connected to %s.', c.connection.as_uri())
+
+    def stop(self, c):
+        pass
+
+    def shutdown(self, c):
+        # We must set self.connection to None here, so
+        # that the green pidbox thread exits.
+        connection, c.connection = c.connection, None
+
+        if connection:
+            c.maybe_conn_error(connection.close)
+
+
+class Events(StartStopComponent):
+    name = 'consumer.events'
+    requires = (ConsumerConnection, )
+
+    def __init__(self, c, send_events=None, **kwargs):
+        self.app = c.app
+        c.event_dispatcher = None
+        self.send_events = send_events
+
+    def start(self, c):
+        # Flush events sent while connection was down.
+        prev_event_dispatcher = c.event_dispatcher
+        c.event_dispatcher = self.app.events.Dispatcher(c.connection,
+                                                hostname=c.hostname,
+                                                enabled=self.send_events)
+        if prev_event_dispatcher:
+            c.event_dispatcher.copy_buffer(prev_event_dispatcher)
+            c.event_dispatcher.flush()
+
+    def stop(self, c):
+        if c.event_dispatcher:
+            debug('  Shutting down event dispatcher...')
+            c.event_dispatcher = \
+                    c.maybe_conn_error(c.event_dispatcher.close)
+
+    def shutdown(self, c):
+        pass
+
+
+class Heartbeat(StartStopComponent):
+    name = 'consumer.heart'
+    requires = (Events, )
+
+    def __init__(self, c, **kwargs):
+        c.heart = None
+
+    def start(self, c):
+        c.heart = Heart(c.timer, c.event_dispatcher)
+        c.heart.start()
+
+    def stop(self, c):
+        if c.heart:
+            # Stop the heartbeat thread if it's running.
+            debug('  Heart: Going into cardiac arrest...')
+            c.heart = c.heart.stop()
+
+    def shutdown(self, c):
+        pass
+
+
+class Controller(StartStopComponent):
+    name = 'consumer.controller'
+    requires = (Events, )
+
+    _pidbox_node_shutdown = None
+    _pidbox_node_stopped = None
+
+    def __init__(self, c, **kwargs):
+        self.app = c.app
+        self.pool = c.pool
+        pidbox_state = AttributeDict(
+            app=c.app, hostname=c.hostname, consumer=c,
+        )
+        self.pidbox_node = self.app.control.mailbox.Node(
+            c.hostname, state=pidbox_state, handlers=Panel.data,
+        )
+        self.broadcast_consumer = None
+        self.consumer = c
+
+    def start(self, c):
+        self.reset_pidbox_node()
+
+    def stop(self, c):
+        pass
+
+    def shutdown(self, c):
+        self.stop_pidbox_node()
+        if self.broadcast_consumer:
+            debug('  Cancelling broadcast consumer...')
+            c.maybe_conn_error(self.broadcast_consumer.cancel)
+            self.broadcast_consumer = None
+
+    def on_control(self, body, message):
+        """Process remote control command message."""
+        try:
+            self.pidbox_node.handle_message(body, message)
+        except KeyError as exc:
+            error('No such control command: %s', exc)
+        except Exception as exc:
+            error('Control command error: %r', exc, exc_info=True)
+            self.reset_pidbox_node()
+
+    def reset_pidbox_node(self):
+        """Sets up the process mailbox."""
+        c = self.consumer
+        self.stop_pidbox_node()
+        # close previously opened channel if any.
+        if self.pidbox_node and self.pidbox_node.channel:
+            c.maybe_conn_error(self.pidbox_node.channel.close)
+
+        if self.pool is not None and self.pool.is_green:
+            return self.pool.spawn_n(self._green_pidbox_node)
+        self.pidbox_node.channel = c.connection.channel()
+        self.broadcast_consumer = self.pidbox_node.listen(
+                                        callback=self.on_control)
+
+    def stop_pidbox_node(self):
+        c = self.consumer
+        if self._pidbox_node_stopped:
+            self._pidbox_node_shutdown.set()
+            debug('  Waiting for broadcast thread to shutdown...')
+            self._pidbox_node_stopped.wait()
+            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
+        elif self.broadcast_consumer:
+            debug('  Closing broadcast channel...')
+            self.broadcast_consumer = \
+                c.maybe_conn_error(self.broadcast_consumer.channel.close)
+
+    def _green_pidbox_node(self):
+        """Sets up the process mailbox when running in a greenlet
+        environment."""
+        # THIS CODE IS TERRIBLE
+        # Luckily work has already started rewriting the Consumer for 4.0.
+        self._pidbox_node_shutdown = threading.Event()
+        self._pidbox_node_stopped = threading.Event()
+        c = self.consumer
+        try:
+            with c._open_connection() as conn:
+                info('pidbox: Connected to %s.', conn.as_uri())
+                self.pidbox_node.channel = conn.default_channel
+                self.broadcast_consumer = self.pidbox_node.listen(
+                                            callback=self.on_control)
+                with self.broadcast_consumer:
+                    while not self._pidbox_node_shutdown.isSet():
+                        try:
+                            conn.drain_events(timeout=1.0)
+                        except socket.timeout:
+                            pass
+        finally:
+            self._pidbox_node_stopped.set()
+
+
+class Tasks(StartStopComponent):
+    name = 'consumer.tasks'
+    requires = (Controller, )
+    last = True
+
+    def __init__(self, c, initial_prefetch_count=2, **kwargs):
+        self.app = c.app
+        c.task_consumer = None
+        c.qos = None
+        self.initial_prefetch_count = initial_prefetch_count
+
+    def start(self, c):
+        c.task_consumer = self.app.amqp.TaskConsumer(c.connection,
+                                    on_decode_error=c.on_decode_error)
+        # QoS: Reset prefetch window.
+        c.qos = QoS(c.task_consumer, self.initial_prefetch_count)
+        c.qos.update()
+
+    def stop(self, c):
+        pass
+
+    def shutdown(self, c):
+        if c.task_consumer:
+            debug('  Cancelling task consumer...')
+            c.maybe_conn_error(c.task_consumer.cancel)
+            debug('  Closing consumer channel...')
+            c.maybe_conn_error(c.task_consumer.close)
+            c.task_consumer = None