Browse Source

Migrate: adds move_tasks and move_task_by_id, tools to move tasks from one queue to another

Ask Solem 12 years ago
parent
commit
aa12fb4b24
3 changed files with 108 additions and 46 deletions
  1. 3 0
      celery/bin/base.py
  2. 28 11
      celery/bin/celery.py
  3. 77 35
      celery/contrib/migrate.py

+ 3 - 0
celery/bin/base.py

@@ -144,6 +144,9 @@ class Command(object):
     #: Text to print in --help before option list.
     description = ''
 
+    #: Set to true if this command doesn't have subcommands
+    leaf = True
+
     def __init__(self, app=None, get_app=None):
         self.app = app
         self.get_app = get_app or self._get_default_app

+ 28 - 11
celery/bin/celery.py

@@ -11,6 +11,7 @@ from __future__ import with_statement
 
 import anyjson
 import sys
+import warnings
 
 from billiard import freeze_support
 from importlib import import_module
@@ -19,6 +20,7 @@ from pprint import pformat
 from celery.platforms import EX_OK, EX_FAILURE, EX_UNAVAILABLE, EX_USAGE
 from celery.utils import term
 from celery.utils import text
+from celery.utils.functional import memoize
 from celery.utils.imports import symbol_by_name
 from celery.utils.timeutils import maybe_iso8601
 
@@ -35,11 +37,19 @@ Type '%(prog_name)s <command> --help' for help using a specific command.
 
 commands = {}
 
-command_classes = (
+command_classes = [
     ('Main', ['worker', 'events', 'beat', 'shell', 'multi', 'amqp'], 'green'),
     ('Remote Control', ['status', 'inspect', 'control'], 'blue'),
     ('Utils', ['purge', 'list', 'migrate', 'call', 'result', 'report'], None),
-)
+]
+
+
+@memoize()
+def _get_extension_classes():
+    extensions = []
+    command_classes.append(('Extensions', extensions, 'magenta'))
+    return extensions
+
 
 class Error(Exception):
 
@@ -65,19 +75,22 @@ def get_extension_commands(namespace='celery.commands'):
         return
 
     for ep in iter_entry_points(namespace):
-        for attr in ep.attrs:
-            command(symbol_by_name(':'.join([ep.module_name, attr])),
-                    name=ep.name)
+        _get_extension_classes().append(ep.name)
+        sym = ':'.join([ep.module_name, ep.attrs[0]])
+        try:
+            cls = symbol_by_name(sym)
+        except (ImportError, SyntaxError), exc:
+            warnings.warn('Cannot load extension %r: %r' % (sym, exc))
+        else:
+            command(cls, name=ep.name)
 get_extension_commands()
 
 
-
 class Command(BaseCommand):
     help = ''
     args = ''
     prog_name = 'celery'
     show_body = True
-    leaf = True
     show_reply = True
 
     option_list = (
@@ -915,10 +928,14 @@ def main():
     # Fix for setuptools generated scripts, so that it will
     # work with multiprocessing fork emulation.
     # (see multiprocessing.forking.get_preparation_data())
-    if __name__ != '__main__':  # pragma: no cover
-        sys.modules['__main__'] = sys.modules[__name__]
-    freeze_support()
-    CeleryCommand().execute_from_commandline()
+    try:
+        if __name__ != '__main__':  # pragma: no cover
+            sys.modules['__main__'] = sys.modules[__name__]
+        freeze_support()
+        CeleryCommand().execute_from_commandline()
+    except KeyboardInterrupt:
+        pass
+
 
 if __name__ == '__main__':          # pragma: no cover
     main()

+ 77 - 35
celery/contrib/migrate.py

@@ -31,17 +31,20 @@ class State(object):
             return u'?'
         return unicode(self.total_apx)
 
+    def __repr__(self):
+        return '%s/%s' % (self.count, self.strtotal)
 
-def migrate_task(producer, body_, message, queues=None,
+
+def republish(producer, message, exchange=None, routing_key=None,
         remove_props=['application_headers',
                       'content_type',
                       'content_encoding',
                       'headers']):
-    queues = {} if queues is None else queues
     body = ensure_bytes(message.body)  # use raw message body.
     info, headers, props = (message.delivery_info,
-                            message.headers,
-                            message.properties)
+                            message.headers, message.properties)
+    exchange = info['exchange'] if exchange is None else exchange
+    routing_key = info['routing_key'] if routing_key is None else routing_key
     ctype, enc = message.content_type, message.content_encoding
     # remove compression header, as this will be inserted again
     # when the message is recompressed.
@@ -50,19 +53,22 @@ def migrate_task(producer, body_, message, queues=None,
     for key in remove_props:
         props.pop(key, None)
 
-    exchange = queues.get(info['exchange'], info['exchange'])
-    routing_key = queues.get(info['routing_key'], info['routing_key'])
-
     producer.publish(ensure_bytes(body), exchange=exchange,
-                           routing_key=routing_key,
-                           compression=compression,
-                           headers=headers,
-                           content_type=ctype,
-                           content_encoding=enc,
-                           **props)
+                     routing_key=routing_key, compression=compression,
+                     headers=headers, content_type=ctype,
+                     content_encoding=enc, **props)
+
+
+def migrate_task(producer, body_, message, queues=None):
+    info = message.delivery_info
+    queues = {} if queues is None else queues
+    republish(producer, message,
+              exchange=queues.get(info['exchange']),
+              routing_key=queues.get(info['routing_key']))
 
 
 def filter_callback(callback, tasks):
+
     def filtered(body, message):
         if tasks and message.payload['task'] not in tasks:
             return
@@ -71,12 +77,50 @@ def filter_callback(callback, tasks):
     return filtered
 
 
-def migrate_tasks(source, dest, limit=None, timeout=1.0, ack_messages=False,
-        app=None, migrate=migrate_task, tasks=None, queues=None, callback=None,
-        forever=False, **kwargs):
-    state = State()
+def migrate_tasks(source, dest, migrate=migrate_task, app=None,
+        queues=None, **kwargs):
+    app = app_or_default(app)
+    queues = prepare_queues(queues)
+    producer = app.amqp.TaskProducer(dest)
+    migrate = partial(migrate, producer, queues=queues)
+
+    def on_declare_queue(queue):
+        new_queue = queue(producer.channel)
+        new_queue.name = queues.get(queue.name, queue.name)
+        if new_queue.routing_key == queue.name:
+            new_queue.routing_key = queues.get(queue.name,
+                                               new_queue.routing_key)
+        if new_queue.exchange.name == queue.name:
+            new_queue.exchange.name = queues.get(queue.name, queue.name)
+        new_queue.declare()
+
+    return start_filter(app, source, migrate, queues=queues,
+                        on_declare_queue=on_declare_queue, **kwargs)
+
+
+def move_tasks(conn, predicate, exchange, routing_key, app=None, **kwargs):
     app = app_or_default(app)
+    producer = app.amqp.TaskProducer(conn)
+
+    def on_task(body, message):
+        if predicate(body, message):
+            republish(producer, message,
+                      exchange=exchange, routing_key=routing_key)
+            message.ack()
 
+    return start_filter(app, conn, on_task, **kwargs)
+
+
+def move_task_by_id(conn, task_id, exchange, routing_key, **kwargs):
+
+    def predicate(body, message):
+        if body['id'] == task_id:
+            return True
+
+    return move_tasks(conn, predicate, exchange, routing_key, **kwargs)
+
+
+def prepare_queues(queues):
     if isinstance(queues, basestring):
         queues = queues.split(',')
     if isinstance(queues, list):
@@ -85,6 +129,12 @@ def migrate_tasks(source, dest, limit=None, timeout=1.0, ack_messages=False,
     if queues is None:
         queues = {}
 
+
+def start_filter(app, conn, filter, limit=None, timeout=1.0,
+        ack_messages=False, migrate=migrate_task, tasks=None, queues=None,
+        callback=None, forever=False, on_declare_queue=None, **kwargs):
+    state = State()
+    queues = prepare_queues(queues)
     if isinstance(tasks, basestring):
         tasks = set(tasks.split(','))
     if tasks is None:
@@ -96,16 +146,14 @@ def migrate_tasks(source, dest, limit=None, timeout=1.0, ack_messages=False,
     def ack_message(body, message):
         message.ack()
 
-    producer = app.amqp.TaskProducer(dest)
-    migrate = partial(migrate, producer, queues=queues)
-    consumer = app.amqp.TaskConsumer(source)
+    consumer = app.amqp.TaskConsumer(conn)
 
     if tasks:
-        migrate = filter_callback(migrate, tasks)
+        filter = filter_callback(filter, tasks)
         update_state = filter_callback(update_state, tasks)
         ack_message = filter_callback(ack_message, tasks)
 
-    consumer.register_callback(migrate)
+    consumer.register_callback(filter)
     consumer.register_callback(update_state)
     if ack_messages:
         consumer.register_callback(ack_message)
@@ -119,28 +167,22 @@ def migrate_tasks(source, dest, limit=None, timeout=1.0, ack_messages=False,
     for queue in consumer.queues:
         if queues and queue.name not in queues:
             continue
-
-        new_queue = queue(producer.channel)
-        new_queue.name = queues.get(queue.name, queue.name)
-        if new_queue.routing_key == queue.name:
-            new_queue.routing_key = queues.get(queue.name,
-                                               new_queue.routing_key)
-        if new_queue.exchange.name == queue.name:
-            new_queue.exchange.name = queues.get(queue.name, queue.name)
-        new_queue.declare()
-
+        if on_declare_queue is not None:
+            on_declare_queue(queue)
         try:
             _, mcount, _ = queue(consumer.channel).queue_declare(passive=True)
             if mcount:
                 state.total_apx += mcount
-        except source.channel_errors + (StdChannelError, ):
+        except conn.channel_errors + (StdChannelError, ):
             pass
 
     # start migrating messages.
     with consumer:
         try:
-            for _ in eventloop(source, limit=limit,  # pragma: no cover
+            for _ in eventloop(conn, limit=limit,  # pragma: no cover
                                timeout=timeout, ignore_timeouts=forever):
                 pass
         except socket.timeout:
-            return
+            pass
+    return state
+