Explorar o código

Migrate: Moving tasks improvements

Ask Solem %!s(int64=12) %!d(string=hai) anos
pai
achega
1d29c10d8b
Modificáronse 1 ficheiros con 46 adicións e 15 borrados
  1. 46 15
      celery/contrib/migrate.py

+ 46 - 15
celery/contrib/migrate.py

@@ -14,7 +14,7 @@ import socket
 from functools import partial
 from itertools import cycle, islice
 
-from kombu import eventloop
+from kombu import eventloop, Queue
 from kombu.exceptions import StdChannelError
 from kombu.utils.encoding import ensure_bytes
 
@@ -98,13 +98,13 @@ def migrate_tasks(source, dest, migrate=migrate_task, app=None,
                         on_declare_queue=on_declare_queue, **kwargs)
 
 
-def move_tasks(conn, predicate, exchange, routing_key, app=None, **kwargs):
+def move(predicate, conn, exchange=None, routing_key=None, app=None, **kwargs):
     """Find tasks by filtering them and move the tasks to a new queue.
 
-    :param conn: Connection to use.
     :param predicate: Filter function with signature ``(body, message)``.
-    :param exchange: Destination exchange.
-    :param routing_key: Destination routing key.
+    :param conn: Connection to use.
+    :keyword exchange: Default destination exchange.
+    :keyword routing_key: Default destination routing key.
 
     Also supports the same keyword arguments as :func:`start_filter`.
 
@@ -117,21 +117,54 @@ def move_tasks(conn, predicate, exchange, routing_key, app=None, **kwargs):
             if body['id'] == wanted_id:
                 return True
 
-        move_tasks(conn, is_wanted_task, exchange, routing_key)
+        move(is_wanted_task, conn, exchange, routing_key)
+
+
+    The predicate may also return a tuple of ``(exchange, routing_key)``
+    to specify the destination to where the task should be moved,
+    or a :class:`~kombu.entitiy.Queue` instance.
+    Any other true value means that the task will be moved to the
+    default exchange/routing_key.
 
     """
     app = app_or_default(app)
     producer = app.amqp.TaskProducer(conn)
 
     def on_task(body, message):
-        if predicate(body, message):
+        ret = predicate(body, message)
+        if ret:
+            if isinstance(ret, Queue):
+                ex, rk = ret.exchange.name, ret.routing_key
+            else:
+                ex, rk = expand_dest(ret, exchange, routing_key)
             republish(producer, message,
-                      exchange=exchange, routing_key=routing_key)
+                      exchange=ex, routing_key=rk)
             message.ack()
 
     return start_filter(app, conn, on_task, **kwargs)
 
 
+def expand_dest(ret, exchange, routing_key):
+    try:
+        ex, rk = ret
+    except (TypeError, ValueError):
+        ex, rk = exchange, routing_key
+    return ex, rk
+
+
+
+# XXX Deprecated (arguments rearranged)
+move_tasks = lambda conn, pred, *a, **kw: move(pred, conn, *a, **kw)
+
+
+def task_id_eq(task_id, body, message):
+    return body['id'] == task_id
+
+
+def task_id_in(ids, body, message):
+    return body['id'] in ids
+
+
 def move_task_by_id(conn, task_id, exchange, routing_key, **kwargs):
     """Find a task by id and move it to another queue.
 
@@ -143,11 +176,8 @@ def move_task_by_id(conn, task_id, exchange, routing_key, **kwargs):
     Also supports the same keyword arguments as :func:`start_filter`.
 
     """
-    def predicate(body, message):
-        if body['id'] == task_id:
-            return True
-
-    return move_tasks(conn, predicate, exchange, routing_key, **kwargs)
+    return move(conn, partial(task_id_eq, task_id),
+                exchange, routing_key, **kwargs)
 
 
 def prepare_queues(queues):
@@ -163,7 +193,8 @@ def prepare_queues(queues):
 
 def start_filter(app, conn, filter, limit=None, timeout=1.0,
         ack_messages=False, migrate=migrate_task, tasks=None, queues=None,
-        callback=None, forever=False, on_declare_queue=None, **kwargs):
+        callback=None, forever=False, on_declare_queue=None,
+        consume_from=None, **kwargs):
     state = State()
     queues = prepare_queues(queues)
     if isinstance(tasks, basestring):
@@ -177,7 +208,7 @@ def start_filter(app, conn, filter, limit=None, timeout=1.0,
     def ack_message(body, message):
         message.ack()
 
-    consumer = app.amqp.TaskConsumer(conn)
+    consumer = app.amqp.TaskConsumer(conn, queues=consume_from)
 
     if tasks:
         filter = filter_callback(filter, tasks)