Selaa lähdekoodia

Remote control commands now using kombu.pidbox

Ask Solem 14 vuotta sitten
vanhempi
commit
5eb20d6dd0

+ 0 - 149
celery/pidbox.py

@@ -1,149 +0,0 @@
-import socket
-import warnings
-
-from itertools import count
-
-from kombu.entity import Exchange, Queue
-from kombu.messaging import Consumer, Producer
-
-from celery.app import app_or_default
-from celery.utils import gen_unique_id
-
-
-
-class Mailbox(object):
-
-    def __init__(self, namespace, connection):
-        self.namespace = namespace
-        self.connection = connection
-        self.exchange = Exchange("%s.pidbox" % (self.namespace, ),
-                                 type="fanout",
-                                 durable=False,
-                                 auto_delete=True,
-                                 delivery_mode="transient")
-        self.reply_exchange = Exchange("reply.%s.pidbox" % (self.namespace, ),
-                                       type="direct",
-                                       durable=False,
-                                       auto_delete=True,
-                                       delivery_mode="transient")
-
-    def publish_reply(self, reply, exchange, routing_key, channel=None):
-        chan = channel or self.connection.channel()
-        try:
-            exchange = Exchange(exchange, exchange_type="direct",
-                                          delivery_mode="transient",
-                                          durable=False,
-                                          auto_delete=True)
-            producer = Producer(chan, exchange=exchange)
-            producer.publish(reply, routing_key=routing_key)
-        finally:
-            channel or chan.close()
-
-    def get_reply_queue(self, ticket):
-        return Queue("%s.%s" % (ticket, self.reply_exchange.name),
-                     exchange=self.reply_exchange,
-                     routing_key=ticket,
-                     durable=False,
-                     auto_delete=True)
-
-    def get_queue(self, hostname):
-        return Queue("%s.%s.pidbox" % (hostname, self.namespace),
-                     exchange=self.exchange)
-
-    def collect_reply(self, ticket, limit=None, timeout=1,
-            callback=None, channel=None):
-        chan = channel or self.connection.channel()
-        queue = self.get_reply_queue(ticket)
-        consumer = Consumer(channel, [queue], no_ack=True)
-        responses = []
-
-        def on_message(message_data, message):
-            if callback:
-                callback(message_data)
-            responses.append(message_data)
-
-        try:
-            consumer.register_callback(on_message)
-            consumer.consume()
-            for i in limit and range(limit) or count():
-                try:
-                    self.connection.drain_events(timeout=timeout)
-                except socket.timeout:
-                    break
-            return responses
-        finally:
-            channel or chan.close()
-
-    def publish(self, type, arguments, destination=None, reply_ticket=None,
-            channel=None):
-        arguments["command"] = type
-        arguments["destination"] = destination
-        if reply_ticket:
-            arguments["reply_to"] = {"exchange": self.reply_exchange.name,
-                                     "routing_key": reply_ticket}
-        chan = channel or self.connection.channel()
-        producer = Producer(chan, exchange=self.exchange)
-        try:
-            producer.publish({"control": arguments})
-        finally:
-            channel or chan.close()
-
-    def Node(self, hostname, channel=None):
-        return Consumer(channel or self.connection.channel(),
-                        [self.get_queue(hostname)],
-                        no_ack=True)
-
-    def call(self, destination, command, kwargs={}, timeout=None,
-            callback=None, channel=None):
-        return self._broadcast(command, kwargs, destination,
-                               reply=True, timeout=timeout,
-                               callback=callback,
-                               channel=channel)
-
-    def cast(self, destination, command, kwargs={}):
-        return self._broadcast(command, kwargs, destination, reply=False)
-
-    def abcast(self, command, kwargs={}):
-        return self._broadcast(command, kwargs, reply=False)
-
-    def multi_call(self, command, kwargs={}, timeout=1,
-            limit=None, callback=None, channel=None):
-        return self._broadcast(command, kwargs, reply=True,
-                               timeout=timeout, limit=limit,
-                               callback=callback,
-                               channel=channel)
-
-    def _broadcast(self, command, arguments=None, destination=None,
-            reply=False, timeout=1, limit=None, callback=None, channel=None):
-        arguments = arguments or {}
-        reply_ticket = reply and gen_unique_id() or None
-
-        if destination is not None and \
-                not isinstance(destination, (list, tuple)):
-            raise ValueError("destination must be a list/tuple not %s" % (
-                    type(destination)))
-
-        # Set reply limit to number of destinations (if specificed)
-        if limit is None and destination:
-            limit = destination and len(destination) or None
-
-        chan = channel or self.connection.channel()
-        try:
-            if reply_ticket:
-                self.get_reply_queue(reply_ticket)(chan).declare()
-
-            self.publish(command, arguments, destination=destination,
-                                             reply_ticket=reply_ticket,
-                                             channel=chan)
-
-            if reply_ticket:
-                return self.collect_reply(reply_ticket, limit=limit,
-                                                        timeout=timeout,
-                                                        callback=callback,
-                                                        channel=chan)
-        finally:
-            channel or chan.close()
-
-
-def mailbox(connection):
-    return Mailbox("celeryd", connection)

+ 8 - 4
celery/task/control.py

@@ -1,5 +1,6 @@
+from kombu.pidbox import Mailbox
+
 from celery.app import app_or_default
-from celery.pidbox import mailbox
 from celery.utils import gen_unique_id
 
 
@@ -72,9 +73,11 @@ class Inspect(object):
 
 
 class Control(object):
+    Mailbox = Mailbox
 
     def __init__(self, app):
         self.app = app
+        self.mailbox = self.Mailbox("celeryd", type="fanout")
 
     def inspect(self, destination=None, timeout=1, callback=None):
         return Inspect(self, destination=destination, timeout=timeout,
@@ -185,9 +188,10 @@ class Control(object):
 
         """
         def _do_broadcast(connection=None, connect_timeout=None):
-            return mailbox(connection)._broadcast(command, arguments,
-                                                  destination, reply,
-                                                  timeout, limit, callback)
+            return self.mailbox(connection)._broadcast(command, arguments,
+                                                       destination, reply,
+                                                       timeout, limit,
+                                                       callback)
 
         return self.app.with_default_connection(_do_broadcast)(
                 connection=connection, connect_timeout=connect_timeout)

+ 33 - 23
celery/tests/test_task_control.py

@@ -1,6 +1,8 @@
 import unittest2 as unittest
 
-from celery.pidbox import Mailbox
+from kombu.pidbox import Mailbox
+
+from celery.app import app_or_default
 from celery.task import control
 from celery.task.builtins import PingTask
 from celery.utils import gen_unique_id
@@ -10,38 +12,38 @@ from celery.utils.functional import wraps
 class MockMailbox(Mailbox):
     sent = []
 
-    def publish(self, command, *args, **kwargs):
+    def _publish(self, command, *args, **kwargs):
         self.__class__.sent.append(command)
 
     def close(self):
         pass
 
-    def collect_reply(self, *args, **kwargs):
+    def _collect(self, *args, **kwargs):
         pass
 
 
-def mock_mailbox(connection):
-    return MockMailbox("celeryd", connection)
+class Control(control.Control):
+    Mailbox = MockMailbox
+
 
 
 def with_mock_broadcast(fun):
 
     @wraps(fun)
-    def _mocked(*args, **kwargs):
-        old_box = control.mailbox
-        control.mailbox = mock_mailbox
+    def _resets(*args, **kwargs):
+        MockMailbox.sent = []
         try:
             return fun(*args, **kwargs)
         finally:
             MockMailbox.sent = []
-            control.mailbox = old_box
-    return _mocked
+    return _resets
 
 
 class test_inspect(unittest.TestCase):
 
     def setUp(self):
-        self.i = control.inspect()
+        app = app_or_default()
+        self.i = Control(app=app).inspect()
 
     def test_prepare_reply(self):
         self.assertDictEqual(self.i._prepare([{"w1": {"ok": 1}},
@@ -100,50 +102,58 @@ class test_inspect(unittest.TestCase):
 
 class test_Broadcast(unittest.TestCase):
 
+    def setUp(self):
+        self.app = app_or_default()
+        self.control = Control(app=self.app)
+        self.app._control = self.control
+
+    def tearDown(self):
+        self.app._control = None
+
     def test_discard_all(self):
-        control.discard_all()
+        self.control.discard_all()
 
     @with_mock_broadcast
     def test_broadcast(self):
-        control.broadcast("foobarbaz", arguments=[])
+        self.control.broadcast("foobarbaz", arguments=[])
         self.assertIn("foobarbaz", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_broadcast_limit(self):
-        control.broadcast("foobarbaz1", arguments=[], limit=None,
+        self.control.broadcast("foobarbaz1", arguments=[], limit=None,
                 destination=[1, 2, 3])
         self.assertIn("foobarbaz1", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_broadcast_validate(self):
-        self.assertRaises(ValueError, control.broadcast, "foobarbaz2",
+        self.assertRaises(ValueError, self.control.broadcast, "foobarbaz2",
                           destination="foo")
 
     @with_mock_broadcast
     def test_rate_limit(self):
-        control.rate_limit(PingTask.name, "100/m")
+        self.control.rate_limit(PingTask.name, "100/m")
         self.assertIn("rate_limit", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_revoke(self):
-        control.revoke("foozbaaz")
+        self.control.revoke("foozbaaz")
         self.assertIn("revoke", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_ping(self):
-        control.ping()
+        self.control.ping()
         self.assertIn("ping", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_revoke_from_result(self):
-        from celery.result import AsyncResult
-        AsyncResult("foozbazzbar").revoke()
+        self.app.AsyncResult("foozbazzbar").revoke()
         self.assertIn("revoke", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_revoke_from_resultset(self):
-        from celery.result import TaskSetResult, AsyncResult
-        r = TaskSetResult(gen_unique_id(), map(AsyncResult, [gen_unique_id()
-                                                        for i in range(10)]))
+        r = self.app.TaskSetResult(gen_unique_id(),
+                                   map(self.app.AsyncResult,
+                                        [gen_unique_id()
+                                            for i in range(10)]))
         r.revoke()
         self.assertIn("revoke", MockMailbox.sent)

+ 6 - 20
celery/tests/test_worker.py

@@ -54,10 +54,10 @@ class MyKombuConsumer(MainConsumer):
         self.heart = None
 
 
-class MockControlDispatch(object):
+class MockNode(object):
     commands = []
 
-    def dispatch_from_message(self, message):
+    def handle_message(self, message_data, message):
         self.commands.append(message.pop("command", None))
 
 
@@ -218,16 +218,6 @@ class test_Consumer(unittest.TestCase):
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
 
-    def test_receive_message_control_command(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-        backend = MockBackend()
-        m = create_message(backend, control={"command": "shutdown"})
-        l.event_dispatcher = MockEventDispatcher()
-        l.control_dispatch = MockControlDispatch()
-        l.receive_message(m.decode(), m)
-        self.assertIn("shutdown", l.control_dispatch.commands)
-
     def test_close_connection(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
@@ -249,7 +239,7 @@ class test_Consumer(unittest.TestCase):
         backend = MockBackend()
         m = create_message(backend, unknown={"baz": "!!!"})
         l.event_dispatcher = MockEventDispatcher()
-        l.control_dispatch = MockControlDispatch()
+        l.pidbox_node = MockNode()
 
         def with_catch_warnings(log):
             l.receive_message(m.decode(), m)
@@ -274,7 +264,7 @@ class test_Consumer(unittest.TestCase):
                                     kwargs={},
                                     eta=datetime.now().isoformat())
         l.event_dispatcher = MockEventDispatcher()
-        l.control_dispatch = MockControlDispatch()
+        l.pidbox_node = MockNode()
 
         prev, consumer.to_timestamp = consumer.to_timestamp, to_timestamp
         try:
@@ -292,7 +282,7 @@ class test_Consumer(unittest.TestCase):
         m = create_message(backend, task=foo_task.name,
             args=(1, 2), kwargs="foobarbaz", id=1)
         l.event_dispatcher = MockEventDispatcher()
-        l.control_dispatch = MockControlDispatch()
+        l.pidbox_node = MockNode()
 
         l.receive_message(m.decode(), m)
         self.assertIn("Invalid task ignored", logger.logged[0])
@@ -368,14 +358,10 @@ class test_Consumer(unittest.TestCase):
                            send_events=False)
         backend = MockBackend()
         id = gen_unique_id()
-        c = create_message(backend, control={"command": "revoke",
-                                             "task_id": id})
         t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                            kwargs={}, id=id)
-        l.event_dispatcher = MockEventDispatcher()
-        l.receive_message(c.decode(), c)
         from celery.worker.state import revoked
-        self.assertIn(id, revoked)
+        revoked.add(id)
 
         l.receive_message(t.decode(), t)
         self.assertTrue(ready_queue.empty())

+ 52 - 36
celery/tests/test_worker_control.py

@@ -1,17 +1,22 @@
 import socket
 import unittest2 as unittest
 
+from datetime import datetime, timedelta
+
+from kombu import pidbox
+
 from celery.utils.timer2 import Timer
 
 from celery.app import app_or_default
+from celery.datastructures import AttributeDict
 from celery.decorators import task
 from celery.registry import tasks
 from celery.task.builtins import PingTask
 from celery.utils import gen_unique_id
-from celery.worker import control
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.state import revoked
+from celery.worker.control.registry import Panel
 
 hostname = socket.gethostname()
 
@@ -53,16 +58,23 @@ class Consumer(object):
 class test_ControlPanel(unittest.TestCase):
 
     def setUp(self):
+        self.app = app_or_default()
         self.panel = self.create_panel(consumer=Consumer())
 
+    def create_state(self, **kwargs):
+        kwargs.setdefault("logger", self.app.log.get_default_logger())
+        return AttributeDict(kwargs)
+
     def create_panel(self, **kwargs):
-        return control.ControlDispatch(hostname=hostname, **kwargs)
+        return self.app.control.mailbox.Node(hostname=hostname,
+                                             state=self.create_state(**kwargs),
+                                             handlers=Panel.data)
 
     def test_disable_events(self):
         consumer = Consumer()
         panel = self.create_panel(consumer=consumer)
         consumer.event_dispatcher.enabled = True
-        panel.execute("disable_events")
+        panel.handle("disable_events")
         self.assertEqual(consumer.event_dispatcher.enabled, False)
         self.assertIn("worker-offline", consumer.event_dispatcher.sent)
 
@@ -70,47 +82,49 @@ class test_ControlPanel(unittest.TestCase):
         consumer = Consumer()
         panel = self.create_panel(consumer=consumer)
         consumer.event_dispatcher.enabled = False
-        panel.execute("enable_events")
+        panel.handle("enable_events")
         self.assertEqual(consumer.event_dispatcher.enabled, True)
         self.assertIn("worker-online", consumer.event_dispatcher.sent)
 
     def test_dump_tasks(self):
-        info = "\n".join(self.panel.execute("dump_tasks"))
+        info = "\n".join(self.panel.handle("dump_tasks"))
         self.assertIn("mytask", info)
         self.assertIn("rate_limit=200", info)
 
     def test_dump_schedule(self):
         consumer = Consumer()
         panel = self.create_panel(consumer=consumer)
-        self.assertFalse(panel.execute("dump_schedule"))
-        import operator
-        consumer.eta_schedule.schedule.enter(100, operator.add, (2, 2))
-        self.assertTrue(panel.execute("dump_schedule"))
+        self.assertFalse(panel.handle("dump_schedule"))
+        r = TaskRequest("celery.ping", "CAFEBABE", (), {})
+        consumer.eta_schedule.schedule.enter(
+                consumer.eta_schedule.Entry(lambda x: x, (r, )),
+                    datetime.now() + timedelta(seconds=10))
+        self.assertTrue(panel.handle("dump_schedule"))
 
     def test_dump_reserved(self):
         consumer = Consumer()
         panel = self.create_panel(consumer=consumer)
-        response = panel.execute("dump_reserved", {"safe": True})
+        response = panel.handle("dump_reserved", {"safe": True})
         self.assertDictContainsSubset({"name": mytask.name,
                                        "args": (2, 2),
                                        "kwargs": {},
                                        "hostname": socket.gethostname()},
                                        response[0])
         consumer.ready_queue = FastQueue()
-        self.assertFalse(panel.execute("dump_reserved"))
+        self.assertFalse(panel.handle("dump_reserved"))
 
     def test_rate_limit_when_disabled(self):
         app = app_or_default()
         app.conf.CELERY_DISABLE_RATE_LIMITS = True
         try:
-            e = self.panel.execute("rate_limit", kwargs=dict(
+            e = self.panel.handle("rate_limit", arguments=dict(
                  task_name=mytask.name, rate_limit="100/m"))
             self.assertIn("rate limits disabled", e.get("error"))
         finally:
             app.conf.CELERY_DISABLE_RATE_LIMITS = False
 
     def test_rate_limit_invalid_rate_limit_string(self):
-        e = self.panel.execute("rate_limit", kwargs=dict(
+        e = self.panel.handle("rate_limit", arguments=dict(
             task_name="tasks.add", rate_limit="x1240301#%!"))
         self.assertIn("Invalid rate limit string", e.get("error"))
 
@@ -133,66 +147,66 @@ class test_ControlPanel(unittest.TestCase):
         task = tasks[PingTask.name]
         old_rate_limit = task.rate_limit
         try:
-            panel.execute("rate_limit", kwargs=dict(task_name=task.name,
-                                                    rate_limit="100/m"))
+            panel.handle("rate_limit", arguments=dict(task_name=task.name,
+                                                      rate_limit="100/m"))
             self.assertEqual(task.rate_limit, "100/m")
             self.assertTrue(consumer.ready_queue.fresh)
             consumer.ready_queue.fresh = False
-            panel.execute("rate_limit", kwargs=dict(task_name=task.name,
-                                                    rate_limit=0))
+            panel.handle("rate_limit", arguments=dict(task_name=task.name,
+                                                      rate_limit=0))
             self.assertEqual(task.rate_limit, 0)
             self.assertTrue(consumer.ready_queue.fresh)
         finally:
             task.rate_limit = old_rate_limit
 
     def test_rate_limit_nonexistant_task(self):
-        self.panel.execute("rate_limit", kwargs={
+        self.panel.handle("rate_limit", arguments={
                                 "task_name": "xxxx.does.not.exist",
                                 "rate_limit": "1000/s"})
 
     def test_unexposed_command(self):
-        self.panel.execute("foo", kwargs={})
+        self.assertRaises(KeyError, self.panel.handle, "foo", arguments={})
 
     def test_revoke_with_name(self):
         uuid = gen_unique_id()
-        m = {"command": "revoke",
+        m = {"method": "revoke",
              "destination": hostname,
-             "task_id": uuid,
-             "task_name": mytask.name}
+             "arguments": {"task_id": uuid,
+                           "task_name": mytask.name}}
         self.panel.dispatch_from_message(m)
         self.assertIn(uuid, revoked)
 
     def test_revoke_with_name_not_in_registry(self):
         uuid = gen_unique_id()
-        m = {"command": "revoke",
+        m = {"method": "revoke",
              "destination": hostname,
-             "task_id": uuid,
-             "task_name": "xxxxxxxxx33333333388888"}
+             "arguments": {"task_id": uuid,
+                           "task_name": "xxxxxxxxx33333333388888"}}
         self.panel.dispatch_from_message(m)
         self.assertIn(uuid, revoked)
 
     def test_revoke(self):
         uuid = gen_unique_id()
-        m = {"command": "revoke",
+        m = {"method": "revoke",
              "destination": hostname,
-             "task_id": uuid}
+             "arguments": {"task_id": uuid}}
         self.panel.dispatch_from_message(m)
         self.assertIn(uuid, revoked)
 
-        m = {"command": "revoke",
+        m = {"method": "revoke",
              "destination": "does.not.exist",
-             "task_id": uuid + "xxx"}
+             "arguments": {"task_id": uuid + "xxx"}}
         self.panel.dispatch_from_message(m)
         self.assertNotIn(uuid + "xxx", revoked)
 
     def test_ping(self):
-        m = {"command": "ping",
+        m = {"method": "ping",
              "destination": hostname}
         r = self.panel.dispatch_from_message(m)
         self.assertEqual(r, "pong")
 
     def test_shutdown(self):
-        m = {"command": "shutdown",
+        m = {"method": "shutdown",
              "destination": hostname}
         self.assertRaises(SystemExit, self.panel.dispatch_from_message, m)
 
@@ -200,14 +214,16 @@ class test_ControlPanel(unittest.TestCase):
 
         replies = []
 
-        class _Dispatch(control.ControlDispatch):
+        class _Node(pidbox.Node):
 
             def reply(self, data, exchange, routing_key, **kwargs):
                 replies.append(data)
 
-        panel = _Dispatch(hostname, consumer=Consumer())
-
-        r = panel.execute("ping", reply_to={"exchange": "x",
-                                            "routing_key": "x"})
+        panel = _Node(hostname=hostname,
+                      state=self.create_state(consumer=Consumer()),
+                      handlers=Panel.data,
+                      mailbox=self.app.control.mailbox)
+        r = panel.dispatch("ping", reply_to={"exchange": "x",
+                                             "routing_key": "x"})
         self.assertEqual(r, "pong")
         self.assertDictEqual(replies[0], {panel.hostname: "pong"})

+ 1 - 20
celery/utils/__init__.py

@@ -4,18 +4,13 @@ import os
 import sys
 import time
 import operator
-try:
-    import ctypes
-except ImportError:
-    ctypes = None
 import importlib
 import logging
 
-from uuid import UUID, uuid4, _uuid_generate_random
 from inspect import getargspec
 from itertools import islice
 
-from kombu.utils import rpartition
+from kombu.utils import gen_unique_id, rpartition
 
 from celery.utils.functional import partial
 
@@ -163,20 +158,6 @@ def chunks(it, n):
         yield [first] + list(islice(it, n - 1))
 
 
-def gen_unique_id():
-    """Generate a unique id, having - hopefully - a very small chance of
-    collission.
-
-    For now this is provided by :func:`uuid.uuid4`.
-    """
-    # Workaround for http://bugs.python.org/issue4607
-    if ctypes and _uuid_generate_random:
-        buffer = ctypes.create_string_buffer(16)
-        _uuid_generate_random(buffer)
-        return str(UUID(bytes=buffer.raw))
-    return str(uuid4())
-
-
 def padlist(container, size, default=None):
     """Pad list with default elements.
 

+ 2 - 2
celery/utils/timer2.py

@@ -75,8 +75,8 @@ class Schedule(object):
         try:
             eta = to_timestamp(eta)
         except OverflowError:
-            self.handle_error(sys.exc_info())
-            return
+            if not self.handle_error(sys.exc_info()):
+                raise
 
         if eta is None:
             # schedule now.

+ 20 - 20
celery/worker/consumer.py

@@ -74,14 +74,13 @@ import socket
 import warnings
 
 from celery.app import app_or_default
-from celery.datastructures import SharedCounter
+from celery.datastructures import AttributeDict, SharedCounter
 from celery.events import EventDispatcher
 from celery.exceptions import NotRegistered
-from celery.pidbox import mailbox
 from celery.utils import noop
 from celery.utils.timer2 import to_timestamp
 from celery.worker.job import TaskRequest, InvalidTaskError
-from celery.worker.control import ControlDispatch
+from celery.worker.control.registry import Panel
 from celery.worker.heartbeat import Heart
 
 RUN = 0x1
@@ -210,10 +209,14 @@ class Consumer(object):
         self.event_dispatcher = None
         self.heart = None
         self.pool = pool
-        self.control_dispatch = ControlDispatch(app=self.app,
-                                                logger=logger,
-                                                hostname=self.hostname,
-                                                consumer=self)
+        pidbox_state = AttributeDict(app=self.app,
+                                     logger=logger,
+                                     hostname=self.hostname,
+                                     listener=self, # pre 2.2
+                                     consumer=self)
+        self.pidbox_node = self.app.control.mailbox.Node(self.hostname,
+                                                         state=pidbox_state,
+                                                         handlers=Panel.data)
         self.connection_errors = \
                 self.app.broker_connection().connection_errors
         self.queues = queues
@@ -283,14 +286,16 @@ class Consumer(object):
         else:
             self.ready_queue.put(task)
 
+    def on_control(self, message, message_data):
+        try:
+            self.pidbox_node.handle_message(message, message_data)
+        except KeyError:
+            self.logger.error("No such control command: %s" % command)
+
     def apply_eta_task(self, task):
         self.ready_queue.put(task)
         self.qos.decrement_eventually()
 
-    def on_control(self, control):
-        """Handle received remote control command."""
-        return self.control_dispatch.dispatch_from_message(control)
-
     def receive_message(self, message_data, message):
         """The callback called when a new message is received. """
 
@@ -314,11 +319,6 @@ class Consumer(object):
                 self.on_task(task)
             return
 
-        # Handle control command
-        control = message_data.get("control")
-        if control:
-            return self.on_control(control)
-
         warnings.warn(RuntimeWarning(
             "Received and deleted unknown message. Wrong destination?!? \
              the message was: %s" % message_data))
@@ -327,7 +327,7 @@ class Consumer(object):
     def maybe_conn_error(self, fun):
         try:
             fun()
-        except Exception:                   # TODO kombu.connection_errors
+        except self.connection_errors:
             pass
 
     def close_connection(self):
@@ -401,9 +401,9 @@ class Consumer(object):
         self.task_consumer.on_decode_error = self.on_decode_error
         self.task_consumer.register_callback(self.receive_message)
 
-        self.broadcast_consumer = mailbox(self.connection).Node(self.hostname)
-        self.broadcast_consumer.register_callback(self.receive_message)
-        self.control_dispatch.channel = self.broadcast_consumer.channel
+        self.pidbox_node.channel = self.connection.channel()
+        self.broadcast_consumer = self.pidbox_node.listen(
+                                        callback=self.on_control)
 
         # Flush events sent while connection was down.
         if self.event_dispatcher:

+ 1 - 72
celery/worker/control/__init__.py

@@ -1,78 +1,7 @@
 import socket
 
 from celery.app import app_or_default
-from celery.pidbox import mailbox
-from celery.utils import kwdict
 from celery.worker.control.registry import Panel
 
+# Loads the built-in remote control commands
 __import__("celery.worker.control.builtins")
-
-
-class ControlDispatch(object):
-    """Execute worker control panel commands."""
-    Panel = Panel
-
-    def __init__(self, logger=None, hostname=None, consumer=None, app=None,
-            channel=None):
-        self.app = app_or_default(app)
-        self.logger = logger or self.app.log.get_default_logger()
-        self.hostname = hostname or socket.gethostname()
-        self.consumer = consumer
-        self.channel = channel
-        self.panel = self.Panel(self.logger, self.consumer, self.hostname,
-                                app=self.app)
-
-    def reply(self, data, exchange, routing_key, **kwargs):
-
-        def _do_reply(connection=None, connect_timeout=None):
-            mailbox(connection).publish_reply(data, exchange, routing_key,
-                                              channel=self.channel)
-
-        self.app.with_default_connection(_do_reply)(**kwargs)
-
-    def dispatch_from_message(self, message):
-        """Dispatch by using message data received by the broker.
-
-        Example:
-
-            >>> def receive_message(message_data, message):
-            ...     control = message_data.get("control")
-            ...     if control:
-            ...         ControlDispatch().dispatch_from_message(control)
-
-        """
-        message = dict(message)             # don't modify callers message.
-        command = message.pop("command")
-        destination = message.pop("destination", None)
-        reply_to = message.pop("reply_to", None)
-        if not destination or self.hostname in destination:
-            return self.execute(command, message, reply_to=reply_to)
-
-    def execute(self, command, kwargs=None, reply_to=None):
-        """Execute control command by name and keyword arguments.
-
-        :param command: Name of the command to execute.
-        :param kwargs: Keyword arguments.
-
-        """
-        kwargs = kwargs or {}
-        control = None
-        try:
-            control = self.panel[command]
-        except KeyError:
-            self.logger.error("No such control command: %s" % command)
-        else:
-            try:
-                reply = control(self.panel, **kwdict(kwargs))
-            except SystemExit:
-                raise
-            except Exception, exc:
-                self.logger.error(
-                        "Error running control command %s kwargs=%s: %s" % (
-                            command, kwargs, exc))
-                reply = {"error": str(exc)}
-            if reply_to:
-                self.reply({self.hostname: reply},
-                           exchange=reply_to["exchange"],
-                           routing_key=reply_to["routing_key"])
-            return reply

+ 0 - 8
celery/worker/control/registry.py

@@ -6,14 +6,6 @@ from celery.app import app_or_default
 class Panel(UserDict):
     data = dict()                               # Global registry.
 
-    def __init__(self, logger, consumer, hostname=None, app=None):
-        self.app = app_or_default(app)
-        self.logger = logger
-        self.hostname = hostname
-        self.consumer = consumer
-        # Compat (pre 2.2)
-        self.listener = consumer
-
     @classmethod
     def register(cls, method, name=None):
         cls.data[name or method.__name__] = method