Browse Source

Routing now works as described in the Routing User Guide.

Ask Solem 15 years ago
parent
commit
7049286964
6 changed files with 50 additions and 34 deletions
  1. 1 1
      Changelog
  2. 15 6
      celery/execute/__init__.py
  3. 5 17
      celery/messaging.py
  4. 26 6
      celery/routes.py
  5. 2 3
      celery/tests/test_routes.py
  6. 1 1
      docs/userguide/routing.rst

+ 1 - 1
Changelog

@@ -306,7 +306,7 @@ News
 
 
         class Router(object):
         class Router(object):
 
 
-            def route_for_task(self, task, task_id=None, args=None, kwargs=None):
+            def route_for_task(self, task, args=None, kwargs=None):
                 if task == "celery.ping":
                 if task == "celery.ping":
                     return "default"
                     return "default"
 
 

+ 15 - 6
celery/execute/__init__.py

@@ -1,13 +1,14 @@
 from celery import conf
 from celery import conf
-from celery.utils import gen_unique_id, fun_takes_kwargs, mattrgetter
-from celery.result import AsyncResult, EagerResult
+from celery.datastructures import ExceptionInfo
 from celery.execute.trace import TaskTrace
 from celery.execute.trace import TaskTrace
-from celery.registry import tasks
 from celery.messaging import with_connection
 from celery.messaging import with_connection
 from celery.messaging import TaskPublisher
 from celery.messaging import TaskPublisher
-from celery.datastructures import ExceptionInfo
+from celery.registry import tasks
+from celery.result import AsyncResult, EagerResult
+from celery.routes import route
+from celery.utils import gen_unique_id, fun_takes_kwargs, mattrgetter
 
 
-extract_exec_options = mattrgetter("routing_key", "exchange",
+extract_exec_options = mattrgetter("queue", "routing_key", "exchange",
                                    "immediate", "mandatory",
                                    "immediate", "mandatory",
                                    "priority", "serializer",
                                    "priority", "serializer",
                                    "delivery_mode")
                                    "delivery_mode")
@@ -16,7 +17,7 @@ extract_exec_options = mattrgetter("routing_key", "exchange",
 @with_connection
 @with_connection
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
         task_id=None, publisher=None, connection=None, connect_timeout=None,
         task_id=None, publisher=None, connection=None, connect_timeout=None,
-        **options):
+        routes=None, routing_table=None, **options):
     """Run a task asynchronously by the celery daemon(s).
     """Run a task asynchronously by the celery daemon(s).
 
 
     :param task: The :class:`~celery.task.base.Task` to run.
     :param task: The :class:`~celery.task.base.Task` to run.
@@ -78,11 +79,19 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
     replaced by a local :func:`apply` call instead.
     replaced by a local :func:`apply` call instead.
 
 
     """
     """
+    if routes is None:
+        routes = conf.ROUTES
+    if routing_table is None:
+        routing_table = conf.get_routing_table()
+
     if conf.ALWAYS_EAGER:
     if conf.ALWAYS_EAGER:
         return apply(task, args, kwargs, task_id=task_id)
         return apply(task, args, kwargs, task_id=task_id)
 
 
     task = tasks[task.name] # get instance from registry
     task = tasks[task.name] # get instance from registry
+
     options = dict(extract_exec_options(task), **options)
     options = dict(extract_exec_options(task), **options)
+    options = route(routes, options, routing_table,
+                    task.name, args, kwargs)
     exchange = options.get("exchange")
     exchange = options.get("exchange")
     exchange_type = options.get("exchange_type")
     exchange_type = options.get("exchange_type")
 
 

+ 5 - 17
celery/messaging.py

@@ -18,9 +18,8 @@ from celery.routes import lookup_route, expand_destination
 from celery.loaders import load_settings
 from celery.loaders import load_settings
 
 
 
 
-MSG_OPTIONS = ("mandatory", "priority",
-               "immediate", "routing_key",
-               "serializer", "delivery_mode")
+MSG_OPTIONS = ("mandatory", "priority", "immediate",
+               "routing_key", "serializer", "delivery_mode")
 
 
 get_msg_options = mitemgetter(*MSG_OPTIONS)
 get_msg_options = mitemgetter(*MSG_OPTIONS)
 extract_msg_options = lambda d: dict(zip(MSG_OPTIONS, get_msg_options(d)))
 extract_msg_options = lambda d: dict(zip(MSG_OPTIONS, get_msg_options(d)))
@@ -56,12 +55,11 @@ class TaskPublisher(Publisher):
         """Delay task for execution by the celery nodes."""
         """Delay task for execution by the celery nodes."""
 
 
         task_id = task_id or gen_unique_id()
         task_id = task_id or gen_unique_id()
-
+        task_args = task_args or []
+        task_kwargs = task_kwargs or {}
         if countdown: # Convert countdown to ETA.
         if countdown: # Convert countdown to ETA.
             eta = datetime.now() + timedelta(seconds=countdown)
             eta = datetime.now() + timedelta(seconds=countdown)
 
 
-        task_args = task_args or []
-        task_kwargs = task_kwargs or {}
         if not isinstance(task_args, (list, tuple)):
         if not isinstance(task_args, (list, tuple)):
             raise ValueError("task args must be a list or tuple")
             raise ValueError("task args must be a list or tuple")
         if not isinstance(task_kwargs, dict):
         if not isinstance(task_kwargs, dict):
@@ -79,17 +77,7 @@ class TaskPublisher(Publisher):
         if taskset_id:
         if taskset_id:
             message_data["taskset"] = taskset_id
             message_data["taskset"] = taskset_id
 
 
-        route = {}
-        if conf.ROUTES:
-            route = lookup_route(conf.ROUTES, task_name, task_id,
-                                 task_args, task_kwargs)
-        if route:
-            dest = expand_destination(route, conf.get_routing_table())
-            msg_options = dict(extract_msg_options(kwargs), **dest)
-        else:
-            msg_options = extract_msg_options(kwargs)
-
-        self.send(message_data, **msg_options)
+        self.send(message_data, **extract_msg_options(kwargs))
         signals.task_sent.send(sender=task_name, **message_data)
         signals.task_sent.send(sender=task_name, **message_data)
 
 
         return task_id
         return task_id

+ 26 - 6
celery/routes.py

@@ -13,15 +13,25 @@ class MapRoute(object):
 
 
 
 
 def expand_destination(route, routing_table):
 def expand_destination(route, routing_table):
+    # The route can simply be a queue name,
+    # this is convenient for direct exchanges.
     if isinstance(route, basestring):
     if isinstance(route, basestring):
+        queue, route = route, {}
+    else:
+        # For topic exchanges you can use the defaults from a queue
+        # definition, and override e.g. just the routing_key.
+        queue = route.pop("queue", None)
+
+    if queue:
         try:
         try:
-            dest = dict(routing_table[route])
+            dest = dict(routing_table[queue])
         except KeyError:
         except KeyError:
             raise RouteNotFound(
             raise RouteNotFound(
                 "Route %s does not exist in the routing table "
                 "Route %s does not exist in the routing table "
                 "(CELERY_QUEUES)" % route)
                 "(CELERY_QUEUES)" % route)
         dest.setdefault("routing_key", dest.get("binding_key"))
         dest.setdefault("routing_key", dest.get("binding_key"))
-        return dest
+        return dict(route, **dest)
+
     return route
     return route
 
 
 
 
@@ -40,6 +50,16 @@ def prepare(routes):
     return map(expand_route, routes)
     return map(expand_route, routes)
 
 
 
 
+def route(routes, options, routing_table, task, args=(), kwargs={}):
+    # Expand "queue" keys in options.
+    options = expand_destination(options, routing_table)
+    if routes:
+        route = lookup_route(routes, task, args, kwargs)
+        # Also expand "queue" keys in route.
+        return dict(options, **expand_destination(route, routing_table))
+    return options
+
+
 def firstmatcher(method):
 def firstmatcher(method):
     """Returns a functions that with a list of instances,
     """Returns a functions that with a list of instances,
     finds the first instance that returns a value for the given method."""
     finds the first instance that returns a value for the given method."""
@@ -59,9 +79,9 @@ _first_route = firstmatcher("route_for_task")
 _first_disabled = firstmatcher("disabled")
 _first_disabled = firstmatcher("disabled")
 
 
 
 
-def lookup_route(routes, task, task_id=None, args=None, kwargs=None):
-    return _first_route(routes, task, task_id, args, kwargs)
+def lookup_route(routes, task, args=None, kwargs=None):
+    return _first_route(routes, task, args, kwargs)
 
 
 
 
-def lookup_disabled(routes, task, task_id=None, args =None, kwargs=None):
-    return _first_disabled(routes, task, task_id, args, kwargs)
+def lookup_disabled(routes, task, args=None, kwargs=None):
+    return _first_disabled(routes, task, args, kwargs)

+ 2 - 3
celery/tests/test_routes.py

@@ -3,7 +3,6 @@ import unittest2 as unittest
 
 
 from celery import conf
 from celery import conf
 from celery import routes
 from celery import routes
-from celery.utils import gen_unique_id
 from celery.utils.functional import wraps
 from celery.utils.functional import wraps
 from celery.exceptions import RouteNotFound
 from celery.exceptions import RouteNotFound
 
 
@@ -69,7 +68,7 @@ class test_lookup_route(unittest.TestCase):
         R = routes.prepare(({"celery.ping": "bar"},
         R = routes.prepare(({"celery.ping": "bar"},
                             {"celery.ping": "foo"}))
                             {"celery.ping": "foo"}))
         self.assertDictContainsSubset(b_route,
         self.assertDictContainsSubset(b_route,
-                expand(routes.lookup_route(R, "celery.ping", gen_unique_id(),
+                expand(routes.lookup_route(R, "celery.ping",
                     args=[1, 2], kwargs={})))
                     args=[1, 2], kwargs={})))
 
 
     @with_queues(foo=a_route, bar=b_route)
     @with_queues(foo=a_route, bar=b_route)
@@ -78,7 +77,7 @@ class test_lookup_route(unittest.TestCase):
         R = routes.prepare(({"celery.xaza": "bar"},
         R = routes.prepare(({"celery.xaza": "bar"},
                             {"celery.ping": "foo"}))
                             {"celery.ping": "foo"}))
         self.assertDictContainsSubset(a_route,
         self.assertDictContainsSubset(a_route,
-                expand(routes.lookup_route(R, "celery.ping", gen_unique_id(),
+                expand(routes.lookup_route(R, "celery.ping",
                     args=[1, 2], kwargs={})))
                     args=[1, 2], kwargs={})))
         self.assertIsNone(routes.lookup_route(R, "celery.poza"))
         self.assertIsNone(routes.lookup_route(R, "celery.poza"))
 
 

+ 1 - 1
docs/userguide/routing.rst

@@ -413,7 +413,7 @@ All you need to define a new router is to create a class with a
 
 
     class MyRouter(object):
     class MyRouter(object):
 
 
-        def route_for_task(task, task_id=None, args=None, kwargs=None):
+        def route_for_task(task, args=None, kwargs=None):
             if task == "myapp.tasks.compress_video":
             if task == "myapp.tasks.compress_video":
                 return {"exchange": "video",
                 return {"exchange": "video",
                         "exchange_type": "topic",
                         "exchange_type": "topic",