Browse Source

Merge branch 'dctrwatson/migration-tool-update'

Ask Solem 12 years ago
parent
commit
34e8a4ca51
2 changed files with 77 additions and 14 deletions
  1. 17 3
      celery/bin/celery.py
  2. 60 11
      celery/contrib/migrate.py

+ 17 - 3
celery/bin/celery.py

@@ -647,8 +647,21 @@ class migrate(Command):
     NOTE: This command is experimental, make sure you have
           a backup of the tasks before you continue.
     """
-    def usage(self, command):
-        return '%%prog %s <source_url> <dest_url>' % (command, )
+    args = '<source_url> <dest_url>'
+    option_list = Command.option_list + (
+            Option('--limit', '-n', type='int',
+                    help='Number of tasks to consume (int)'),
+            Option('--timeout', '-t', type='float', default=1.0,
+                    help='Timeout in seconds (float) waiting for tasks'),
+            Option('--ack-messages', '-a', action='store_true',
+                    help='Ack messages from source broker.'),
+            Option('--tasks', '-T',
+                    help='List of task names to filter on.'),
+            Option('--queues', '-Q',
+                    help='List of queues to migrate.'),
+            Option('--forever', '-F', action='store_true',
+                    help='Continually migrate tasks until killed.'),
+    )
 
     def on_migrate_task(self, state, body, message):
         self.out('Migrating task %s/%s: %s[%s]' % (
@@ -662,7 +675,8 @@ class migrate(Command):
 
         migrate_tasks(Connection(args[0]),
                       Connection(args[1]),
-                      callback=self.on_migrate_task)
+                      callback=self.on_migrate_task,
+                      **kwargs)
 migrate = command(migrate)
 
 

+ 60 - 11
celery/contrib/migrate.py

@@ -12,6 +12,7 @@ from __future__ import with_statement
 import socket
 
 from functools import partial
+from itertools import cycle, islice
 
 from kombu import eventloop
 from kombu.exceptions import StdChannelError
@@ -31,10 +32,11 @@ class State(object):
         return unicode(self.total_apx)
 
 
-def migrate_task(producer, body_, message,
+def migrate_task(producer, body_, message, queues=None,
         remove_props=['application_headers',
                       'content_type',
-                      'content_encoding']):
+                      'content_encoding',
+                      'headers']):
     body = ensure_bytes(message.body)  # use raw message body.
     info, headers, props = (message.delivery_info,
                             message.headers,
@@ -47,8 +49,11 @@ def migrate_task(producer, body_, message,
     for key in remove_props:
         props.pop(key, None)
 
-    producer.publish(ensure_bytes(body), exchange=info['exchange'],
-                           routing_key=info['routing_key'],
+    exchange = queues.get(info['exchange'], info['exchange'])
+    routing_key = queues.get(info['routing_key'], info['routing_key'])
+
+    producer.publish(ensure_bytes(body), exchange=exchange,
+                           routing_key=routing_key,
                            compression=compression,
                            headers=headers,
                            content_type=ctype,
@@ -56,27 +61,71 @@ def migrate_task(producer, body_, message,
                            **props)
 
 
-def migrate_tasks(source, dest, timeout=1.0, app=None,
-        migrate=None, callback=None):
+def filter_callback(callback, tasks):
+    def filtered(body, message):
+        if tasks and message.payload['task'] not in tasks:
+            return
+
+        return callback(body, message)
+    return filtered
+
+
+def migrate_tasks(source, dest, limit=None, timeout=1.0, ack_messages=False,
+        app=None, migrate=migrate_task, tasks=None, queues=None, callback=None,
+        forever=False, **kwargs):
     state = State()
     app = app_or_default(app)
 
+    if isinstance(queues, basestring):
+        queues = queues.split(',')
+    if isinstance(queues, list):
+        queues = dict([tuple(islice(cycle(q.split(':')), None, 2)) for q in queues])
+    if queues is None:
+        queues = {}
+
+    if isinstance(tasks, basestring):
+        tasks = set(tasks.split(','))
+    if tasks is None:
+        tasks = set([])
+
     def update_state(body, message):
         state.count += 1
 
+    def ack_message(body, message):
+        message.ack()
+
     producer = app.amqp.TaskProducer(dest)
-    if migrate is None:
-        migrate = partial(migrate_task, producer)
+    migrate = partial(migrate, producer, queues=queues)
     consumer = app.amqp.TaskConsumer(source)
+
+    if tasks:
+        migrate = filter_callback(migrate, tasks)
+        update_state = filter_callback(update_state, tasks)
+        ack_message = filter_callback(ack_message, tasks)
+
+    consumer.register_callback(migrate)
     consumer.register_callback(update_state)
+    if ack_messages:
+        consumer.register_callback(ack_message)
     if callback is not None:
         callback = partial(callback, state)
+        if tasks:
+            callback = filter_callback(callback, tasks)
         consumer.register_callback(callback)
-    consumer.register_callback(migrate)
 
     # declare all queues on the new broker.
     for queue in consumer.queues:
-        queue(producer.channel).declare()
+        if queues and queue.name not in queues:
+            continue
+
+        new_queue = queue(producer.channel)
+        new_queue.name = queues.get(queue.name, queue.name)
+        if new_queue.routing_key == queue.name:
+            new_queue.routing_key = queues.get(queue.name, new_queue.routing_key)
+        if new_queue.exchange.name == queue.name:
+            new_queue.exchange.name = queues.get(queue.name, queue.name)
+        new_queue.declare()
+
         try:
             _, mcount, _ = queue(consumer.channel).queue_declare(passive=True)
             if mcount:
@@ -87,7 +136,7 @@ def migrate_tasks(source, dest, timeout=1.0, app=None,
     # start migrating messages.
     with consumer:
         try:
-            for _ in eventloop(source, timeout=timeout):  # pragma: no cover
+            for _ in eventloop(source, limit=limit, timeout=timeout, ignore_timeouts=forever):  # pragma: no cover
                 pass
         except socket.timeout:
             return