Forráskód Böngészése

Added support for router classes (like the django multidb routers)

New setting: CELERY_ROUTES

This is a single, or a list of routers to traverse when
sending tasks. Dicts in this list converts to a "celery.routes.MapRoute"
instance.

Example:

    >>> CELERY_ROUTES = {"celery.ping": "default",
                          "mytasks.add": "cpu-bound",
                          "video.encode": {
                            "queue": "video",
                            "exchange": "media"
                            "routing_key": "media.video.encode"}}

    >>> CELERY_ROUTES = ("myapp.tasks.Router",
                         {"celery.ping": "default})

Where myapp.tasks.Router could be:

    class Router(object):

        def route_for_task(self, task, task_id=None, args=None, kwargs=None):
            if task == "celery.ping":
                return "default"

route_for_task may return a string or a dict. A string then means
it's a queue name in CELERY_QUEUES, a dict means it's a custom route.

When sending tasks, the routers are consulted in order. The first
router that doesn't return None is the route to use. The message options
is then merged with the found route settings, where the routers settings
have priority.

Example if apply_async has these arguments::

    Task.apply_async(immediate=False, exchange="video",
routing_key="video.compress")

and a router returns::

    {"immediate": True,
     "exchange": "urgent"}

the final message options will be::

    immediate=True, exchange="urgent", routing_key="video.compress"

(and any default message options defined in the Task class)

Closes #79
Ask Solem 15 éve
szülő
commit
6d701af9e6
6 módosított fájl, 216 hozzáadás és 96 törlés
  1. 18 0
      celery/conf.py
  2. 4 0
      celery/exceptions.py
  3. 12 1
      celery/messaging.py
  4. 68 0
      celery/routes.py
  5. 0 95
      celery/task/route.py
  6. 114 0
      celery/tests/test_routes.py

+ 18 - 0
celery/conf.py

@@ -2,6 +2,7 @@ import logging
 import warnings
 from datetime import timedelta
 
+from celery import routes
 from celery.loaders import load_settings
 
 DEFAULT_PROCESS_LOG_FMT = """
@@ -32,6 +33,7 @@ _DEFAULTS = {
     "CELERYD_TASK_TIME_LIMIT": None,
     "CELERYD_TASK_SOFT_TIME_LIMIT": None,
     "CELERYD_MAX_TASKS_PER_CHILD": None,
+    "CELERY_ROUTES": None,
     "CELERY_DEFAULT_ROUTING_KEY": "celery",
     "CELERY_DEFAULT_QUEUE": "celery",
     "CELERY_DEFAULT_EXCHANGE": "celery",
@@ -156,6 +158,22 @@ QUEUES = _get("CELERY_QUEUES") or {DEFAULT_QUEUE: {
                                        "exchange_type": DEFAULT_EXCHANGE_TYPE,
                                        "binding_key": DEFAULT_ROUTING_KEY}}
 
+# CELERY_ROUTES initialization
+"""
+
+    >>> CELERY_ROUTES = {"celery.ping": "default",
+                          "mytasks.add": "cpu-bound",
+                          "video.encode": {
+                            "queue": "video",
+                            "exchange": "media"
+                            "routing_key": "media.video.encode"}}
+
+    >>> CELERY_ROUTES = ("myapp.tasks.Router",
+                         {"celery.ping": "default})
+
+"""
+
+ROUTES = routes.prepare(_get("CELERY_ROUTES") or [])
 # :--- Broadcast queue settings                     <-   --   --- - ----- -- #
 
 BROADCAST_QUEUE = _get("CELERY_BROADCAST_QUEUE")

+ 4 - 0
celery/exceptions.py

@@ -10,6 +10,10 @@ Task of kind %s is not registered, please make sure it's imported.
 """.strip()
 
 
+class RouteNotFound(KeyError):
+    """Task routed to a queue not in the routing table (CELERY_QUEUES)."""
+
+
 class SoftTimeLimitExceeded(_SoftTimeLimitExceeded):
     """The soft time limit has been exceeded. This exception is raised
     to give the task a chance to clean up."""

+ 12 - 1
celery/messaging.py

@@ -14,6 +14,7 @@ from billiard.utils.functional import wraps
 from celery import conf
 from celery import signals
 from celery.utils import gen_unique_id, mitemgetter, noop
+from celery.routes import lookup_route, expand_destination
 from celery.loaders import load_settings
 
 
@@ -78,7 +79,17 @@ class TaskPublisher(Publisher):
         if taskset_id:
             message_data["taskset"] = taskset_id
 
-        self.send(message_data, **extract_msg_options(kwargs))
+        route = {}
+        if conf.ROUTES:
+            route = lookup_route(conf.ROUTES, task_name, task_id,
+                                 task_args, task_kwargs)
+        if route:
+            dest = expand_destination(route, conf.get_routing_table())
+            msg_options = dict(extract_msg_options(kwargs), **dest)
+        else:
+            msg_options = extract_msg_options(kwargs)
+
+        self.send(message_data, **msg_options)
         signals.task_sent.send(sender=task_name, **message_data)
 
         return task_id

+ 68 - 0
celery/routes.py

@@ -0,0 +1,68 @@
+from celery.utils import instantiate
+from celery.exceptions import RouteNotFound
+
+
+# Route from mapping
+class MapRoute(object):
+
+    def __init__(self, map):
+        self.map = map
+
+    def route_for_task(self, task, *args, **kwargs):
+        return self.map.get(task)
+
+
+def expand_destination(route, routing_table):
+    if isinstance(route, basestring):
+        try:
+            dest = dict(routing_table[route])
+        except KeyError, exc:
+            raise RouteNotFound(
+                "Route %s does not exist in the routing table "
+                "(CELERY_QUEUES)" % route)
+        dest.setdefault("routing_key", dest.get("binding_key"))
+        return dest
+    return route
+
+
+def prepare(routes):
+    """Expand ROUTES setting."""
+
+    def expand_route(route):
+        if isinstance(route, dict):
+            return MapRoute(route)
+        if isinstance(route, basestring):
+            return instantiate(route)
+        return route
+
+    if not hasattr(routes, "__iter__"):
+        routes = (routes, )
+    return map(expand_route, routes)
+
+
+
+def firstmatcher(method):
+    """With a list of instances, find the first instance that returns a
+    value for the given method."""
+
+    def _matcher(seq, *args, **kwargs):
+        for cls in seq:
+            try:
+                answer = getattr(cls, method)(*args, **kwargs)
+                if answer is not None:
+                    return answer
+            except AttributeError:
+                pass
+    return _matcher
+
+
+_first_route = firstmatcher("route_for_task")
+_first_disabled = firstmatcher("disabled")
+
+
+def lookup_route(routes, task, task_id=None, args=None, kwargs=None):
+    return _first_route(routes, task, task_id, args, kwargs)
+
+
+def lookup_disabled(routes, task, task_id=None, args =None, kwargs=None):
+    return _first_disabled(routes, task, task_id, args, kwargs)

+ 0 - 95
celery/task/route.py

@@ -1,95 +0,0 @@
-from celery import conf
-from celery.utils import get_cls_by_name
-
-default_queue = conf.routing_table[conf.DEFAULT_QUEUE]
-
-
-# Custom Router
-class Router(object):
-
-    def route_for_task(self, task, task_id=None, args=None, kwargs=None):
-        return {"queue": conf.DEFAULT_QUEUE,
-                "exchange": default_queue["exchange"],
-                "routing_key": conf.DEFAULT_ROUTING_KEY}
-
-    def disabled(self, task, task_id=None, args=None, kwargs=None):
-        if task.name == "celery.ping":
-            return True
-
-
-
-# Route from mapping
-class MapRoute(object):
-
-    def __init__(self, map):
-        self.map = dict((name, self._expand_destination(entry)
-                            for name, entry in map.items())
-
-    def route_for_task(self, task, **kwargs):
-        return self.map.get(task.name)
-
-    def _expand_destination(self, entry):
-        if isinstance(entry, basestring):
-            dest = dict(conf.routing_table[entry])
-            dest.setdefault("routing_key", dest.get("binding_key"))
-            return dest
-        return entry
-
-
-
-# CELERY_ROUTES initialization
-"""
-
-    >>> CELERY_ROUTES = {"celery.ping": "default",
-                          "mytasks.add": "cpu-bound",
-                          "video.encode": {
-                            "queue": "video",
-                            "exchange": "media"
-                            "routing_key": "media.video.encode"}}
-
-    >>> CELERY_ROUTES = ("myapp.tasks.Router",
-                         {"celery.ping": "default})
-
-"""
-
-
-def expand_route(route):
-    if hasattr(route, "items"):
-        return MapRoute(route)
-    if isinstance(route, "basestring"):
-        return get_cls_by_name(route)()
-    return route
-
-routes = _get("CELERY_ROUTES", [])
-if not hasattr(routes, "__iter__"):
-    routes = (routes, )
-routes = map(expand_route, routes)
-
-
-# Traversing routes
-
-def firstmatcher(seq, method):
-    """With a list of instances, find the first instance that returns a
-    value for the given method."""
-
-    def _matcher(*args, **kwargs):
-        for cls in seq:
-            try:
-                answer = getattr(cls, method)(*args, **kwargs)
-                if answer:
-                    return answer
-            except AttributeError:
-                pass
-
-    return _matcher
-
-
-_first_route = firstmatcher(routes, "route_for_task")
-_first_disabled = firstmatcher(routes, "disabled")
-
-def lookup_route(task, task_id=None, args=None, kwargs=None):
-    return _first_route(task, task_id, args, kwargs)
-
-def lookup_disabled(task, task_id=None, args=None, kwargs=None):
-    return _first_disabled(task, task_id, args, kwargs)
-

+ 114 - 0
celery/tests/test_routes.py

@@ -0,0 +1,114 @@
+import unittest2 as unittest
+
+from billiard.utils.functional import wraps
+
+from celery import conf
+from celery import routes
+from celery.utils import gen_unique_id
+from celery.exceptions import RouteNotFound
+
+
+def E(routing_table):
+    def expand(answer):
+        return routes.expand_destination(answer, routing_table)
+    return expand
+
+
+def with_queues(**queues):
+
+    def patch_fun(fun):
+        @wraps(fun)
+        def __inner(*args, **kwargs):
+            prev_queues = conf.QUEUES
+            conf.QUEUES = queues
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                conf.QUEUES = prev_queues
+        return __inner
+    return patch_fun
+
+
+a_route = {"exchange": "fooexchange",
+           "exchange_type": "fanout",
+               "binding_key": "xuzzy"}
+b_route = {"exchange": "barexchange",
+           "exchange_type": "topic",
+           "binding_key": "b.b.#"}
+
+
+class test_MapRoute(unittest.TestCase):
+
+    @with_queues(foo=a_route, bar=b_route)
+    def test_route_for_task_expanded_route(self):
+        expand = E(conf.QUEUES)
+        route = routes.MapRoute({"celery.ping": "foo"})
+        self.assertDictContainsSubset(a_route,
+                             expand(route.route_for_task("celery.ping")))
+        self.assertIsNone(route.route_for_task("celery.awesome"))
+
+    @with_queues(foo=a_route, bar=b_route)
+    def test_route_for_task(self):
+        expand = E(conf.QUEUES)
+        route = routes.MapRoute({"celery.ping": b_route})
+        self.assertDictContainsSubset(b_route,
+                             expand(route.route_for_task("celery.ping")))
+        self.assertIsNone(route.route_for_task("celery.awesome"))
+
+    def test_expand_route_not_found(self):
+        expand = E(conf.QUEUES)
+        route = routes.MapRoute({"a": "x"})
+        self.assertRaises(RouteNotFound, expand, route.route_for_task("a"))
+
+
+class test_lookup_route(unittest.TestCase):
+
+    @with_queues(foo=a_route, bar=b_route)
+    def test_lookup_takes_first(self):
+        expand = E(conf.QUEUES)
+        R = routes.prepare(({"celery.ping": "bar"},
+                            {"celery.ping": "foo"}))
+        self.assertDictContainsSubset(b_route,
+                expand(routes.lookup_route(R, "celery.ping", gen_unique_id(),
+                    args=[1, 2], kwargs={})))
+
+    @with_queues(foo=a_route, bar=b_route)
+    def test_lookup_paths_traversed(self):
+        expand = E(conf.QUEUES)
+        R = routes.prepare(({"celery.xaza": "bar"},
+                            {"celery.ping": "foo"}))
+        self.assertDictContainsSubset(a_route,
+                expand(routes.lookup_route(R, "celery.ping", gen_unique_id(),
+                    args=[1, 2], kwargs={})))
+        self.assertIsNone(routes.lookup_route(R, "celery.poza"))
+
+
+class test_lookup_disabled(unittest.TestCase):
+
+    def test_disabled(self):
+
+        def create_router(name, is_disabled):
+            class _Router(object):
+
+                def disabled(self, task, *args):
+                    if task == name:
+                        return is_disabled
+            return _Router()
+
+
+        A = create_router("celery.ping", True)
+        B = create_router("celery.ping", False)
+        C = object()
+
+        R1 = (routes.prepare((A, B, C)), True)
+        R2 = (routes.prepare((B, C, A)), False)
+        R3 = (routes.prepare((C, A, B)), True)
+        R4 = (routes.prepare((B, A, C)), False)
+        R5 = (routes.prepare((A, C, B)), True)
+        R6 = (routes.prepare((C, B, A)), False)
+
+        for i, (router, state) in enumerate((R1, R2, R3, R4, R5, R6)):
+            self.assertEqual(routes.lookup_disabled(router, "celery.ping"),
+                             state, "ok %d" % i)
+
+