Explorar el Código

Refactored celery.routes

Ask Solem hace 15 años
padre
commit
c3efa92afe
Se han modificado 5 ficheros con 68 adiciones y 99 borrados
  1. 4 8
      celery/execute/__init__.py
  2. 0 1
      celery/messaging.py
  3. 41 53
      celery/routes.py
  4. 8 37
      celery/tests/test_routes.py
  5. 15 0
      celery/utils/__init__.py

+ 4 - 8
celery/execute/__init__.py

@@ -5,7 +5,7 @@ from celery.messaging import with_connection
 from celery.messaging import TaskPublisher
 from celery.registry import tasks
 from celery.result import AsyncResult, EagerResult
-from celery.routes import route
+from celery.routes import Router
 from celery.utils import gen_unique_id, fun_takes_kwargs, mattrgetter
 
 extract_exec_options = mattrgetter("queue", "routing_key", "exchange",
@@ -17,7 +17,7 @@ extract_exec_options = mattrgetter("queue", "routing_key", "exchange",
 @with_connection
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
         task_id=None, publisher=None, connection=None, connect_timeout=None,
-        routes=None, queues=None, **options):
+        router=None, **options):
     """Run a task asynchronously by the celery daemon(s).
 
     :param task: The :class:`~celery.task.base.Task` to run.
@@ -79,10 +79,7 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
     replaced by a local :func:`apply` call instead.
 
     """
-    if routes is None:
-        routes = conf.ROUTES
-    if queues is None:
-        queues = conf.get_queues()
+    router = router or Router(conf.ROUTES, conf.get_queues())
 
     if conf.ALWAYS_EAGER:
         return apply(task, args, kwargs, task_id=task_id)
@@ -90,8 +87,7 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
     task = tasks[task.name] # get instance from registry
 
     options = dict(extract_exec_options(task), **options)
-    options = route(routes, options, queues,
-                    task.name, args, kwargs)
+    options = router.route(options, task.name, args, kwargs)
     exchange = options.get("exchange")
     exchange_type = options.get("exchange_type")
 

+ 0 - 1
celery/messaging.py

@@ -14,7 +14,6 @@ from celery import conf
 from celery import signals
 from celery.utils import gen_unique_id, mitemgetter, noop
 from celery.utils.functional import wraps
-from celery.routes import lookup_route, expand_destination
 from celery.loaders import load_settings
 
 

+ 41 - 53
celery/routes.py

@@ -1,5 +1,7 @@
 from celery.exceptions import RouteNotFound
-from celery.utils import instantiate
+from celery.utils import instantiate, firstmethod
+
+_first_route = firstmethod("route_for_task")
 
 
 class MapRoute(object):
@@ -12,27 +14,46 @@ class MapRoute(object):
         return self.map.get(task)
 
 
-def expand_destination(route, queues):
-    # The route can simply be a queue name,
-    # this is convenient for direct exchanges.
-    if isinstance(route, basestring):
-        queue, route = route, {}
-    else:
-        # For topic exchanges you can use the defaults from a queue
-        # definition, and override e.g. just the routing_key.
-        queue = route.pop("queue", None)
+class Router(object):
+
+    def __init__(self, routes, queues):
+        self.queues = queues
+        self.routes = routes
+
+    def route(self, options, task, args=(), kwargs={}):
+        # Expand "queue" keys in options.
+        options = self.expand_destination(options)
+        if self.routes:
+            route = self.lookup_route(task, args, kwargs)
+            if route:
+                # Also expand "queue" keys in route.
+                return dict(options, **self.expand_destination(route))
+        return options
+
+    def expand_destination(self, route):
+        # The route can simply be a queue name,
+        # this is convenient for direct exchanges.
+        if isinstance(route, basestring):
+            queue, route = route, {}
+        else:
+            # For topic exchanges you can use the defaults from a queue
+            # definition, and override e.g. just the routing_key.
+            queue = route.pop("queue", None)
+
+        if queue:
+            try:
+                dest = dict(self.queues[queue])
+            except KeyError:
+                raise RouteNotFound(
+                    "Route %s does not exist in the routing table "
+                    "(CELERY_QUEUES)" % route)
+            dest.setdefault("routing_key", dest.get("binding_key"))
+            return dict(route, **dest)
 
-    if queue:
-        try:
-            dest = dict(queues[queue])
-        except KeyError:
-            raise RouteNotFound(
-                "Route %s does not exist in the routing table "
-                "(CELERY_QUEUES)" % route)
-        dest.setdefault("routing_key", dest.get("binding_key"))
-        return dict(route, **dest)
+        return route
 
-    return route
+    def lookup_route(self, task, args=None, kwargs=None):
+        return _first_route(self.routes, task, args, kwargs)
 
 
 def prepare(routes):
@@ -50,38 +71,5 @@ def prepare(routes):
     return map(expand_route, routes)
 
 
-def route(routes, options, queues, task, args=(), kwargs={}):
-    # Expand "queue" keys in options.
-    options = expand_destination(options, queues)
-    if routes:
-        route = lookup_route(routes, task, args, kwargs)
-        # Also expand "queue" keys in route.
-        return dict(options, **expand_destination(route, queues))
-    return options
-
-
-def firstmatcher(method):
-    """Returns a functions that with a list of instances,
-    finds the first instance that returns a value for the given method."""
-
-    def _matcher(seq, *args, **kwargs):
-        for cls in seq:
-            try:
-                answer = getattr(cls, method)(*args, **kwargs)
-                if answer is not None:
-                    return answer
-            except AttributeError:
-                pass
-    return _matcher
-
-
-_first_route = firstmatcher("route_for_task")
-_first_disabled = firstmatcher("disabled")
-
-
-def lookup_route(routes, task, args=None, kwargs=None):
-    return _first_route(routes, task, args, kwargs)
 
 
-def lookup_disabled(routes, task, args=None, kwargs=None):
-    return _first_disabled(routes, task, args, kwargs)

+ 8 - 37
celery/tests/test_routes.py

@@ -9,7 +9,7 @@ from celery.exceptions import RouteNotFound
 
 def E(queues):
     def expand(answer):
-        return routes.expand_destination(answer, queues)
+        return routes.Router([], queues).expand_destination(answer)
     return expand
 
 
@@ -64,48 +64,19 @@ class test_lookup_route(unittest.TestCase):
 
     @with_queues(foo=a_queue, bar=b_queue)
     def test_lookup_takes_first(self):
-        expand = E(conf.QUEUES)
         R = routes.prepare(({"celery.ping": "bar"},
                             {"celery.ping": "foo"}))
+        router = routes.Router(R, conf.QUEUES)
         self.assertDictContainsSubset(b_queue,
-                expand(routes.lookup_route(R, "celery.ping",
-                    args=[1, 2], kwargs={})))
+                router.route({}, "celery.ping",
+                    args=[1, 2], kwargs={}))
 
     @with_queues(foo=a_queue, bar=b_queue)
     def test_lookup_paths_traversed(self):
-        expand = E(conf.QUEUES)
         R = routes.prepare(({"celery.xaza": "bar"},
                             {"celery.ping": "foo"}))
+        router = routes.Router(R, conf.QUEUES)
         self.assertDictContainsSubset(a_queue,
-                expand(routes.lookup_route(R, "celery.ping",
-                    args=[1, 2], kwargs={})))
-        self.assertIsNone(routes.lookup_route(R, "celery.poza"))
-
-
-class test_lookup_disabled(unittest.TestCase):
-
-    def test_disabled(self):
-
-        def create_router(name, is_disabled):
-            class _Router(object):
-
-                def disabled(self, task, *args):
-                    if task == name:
-                        return is_disabled
-            return _Router()
-
-
-        A = create_router("celery.ping", True)
-        B = create_router("celery.ping", False)
-        C = object()
-
-        R1 = (routes.prepare((A, B, C)), True)
-        R2 = (routes.prepare((B, C, A)), False)
-        R3 = (routes.prepare((C, A, B)), True)
-        R4 = (routes.prepare((B, A, C)), False)
-        R5 = (routes.prepare((A, C, B)), True)
-        R6 = (routes.prepare((C, B, A)), False)
-
-        for i, (router, state) in enumerate((R1, R2, R3, R4, R5, R6)):
-            self.assertEqual(routes.lookup_disabled(router, "celery.ping"),
-                             state, "ok %d" % i)
+                router.route({}, "celery.ping",
+                    args=[1, 2], kwargs={}))
+        self.assertEqual(router.route({}, "celery.poza"), {})

+ 15 - 0
celery/utils/__init__.py

@@ -50,6 +50,21 @@ def first(predicate, iterable):
             return item
 
 
+def firstmethod(method):
+    """Returns a functions that with a list of instances,
+    finds the first instance that returns a value for the given method."""
+
+    def _matcher(seq, *args, **kwargs):
+        for cls in seq:
+            try:
+                answer = getattr(cls, method)(*args, **kwargs)
+                if answer is not None:
+                    return answer
+            except AttributeError:
+                pass
+    return _matcher
+
+
 def chunks(it, n):
     """Split an iterator into chunks with ``n`` elements each.