Explorar o código

Fixes backward compatibility for #2217

Ask Solem %!s(int64=9) %!d(string=hai) anos
pai
achega
16f927185d
Modificáronse 3 ficheiros con 28 adicións e 8 borrados
  1. 7 1
      celery/app/routes.py
  2. 10 0
      celery/tests/utils/test_functional.py
  3. 11 7
      celery/utils/functional.py

+ 7 - 1
celery/app/routes.py

@@ -19,11 +19,17 @@ from kombu import Queue
 from celery.exceptions import QueueNotFound
 from celery.five import items, string_t
 from celery.utils import lpmerge
-from celery.utils.functional import firstmethod, mlazy
+from celery.utils.functional import firstmethod, fun_takes_argument, mlazy
 from celery.utils.imports import instantiate
 
 __all__ = ['MapRoute', 'Router', 'prepare']
 
+
+def _try_route(meth, task, args, kwargs, options=None):
+    if fun_takes_argument('options', meth, position=4):
+        return meth(task, args, kwargs, options)
+    return meth(task, args, kwargs)
+
 _first_route = firstmethod('route_for_task')
 
 

+ 10 - 0
celery/tests/utils/test_functional.py

@@ -319,9 +319,19 @@ class test_fun_takes_argument(Case):
     def test_named(self):
         self.assertTrue(fun_takes_argument('foo', lambda a, foo, bar: 1))
 
+        def fun(a, b, c, d):
+            return 1
+
+        self.assertTrue(fun_takes_argument('foo', fun, position=4))
+
     def test_starargs(self):
         self.assertTrue(fun_takes_argument('foo', lambda a, *args: 1))
 
     def test_does_not(self):
         self.assertFalse(fun_takes_argument('foo', lambda a, bar, baz: 1))
         self.assertFalse(fun_takes_argument('foo', lambda: 1))
+
+        def fun(a, b, foo):
+            return 1
+
+        self.assertFalse(fun_takes_argument('foo', fun, position=4))

+ 11 - 7
celery/utils/functional.py

@@ -237,7 +237,7 @@ def first(predicate, it):
     )
 
 
-def firstmethod(method):
+def firstmethod(method, on_call=None):
     """Return a function that with a list of instances,
     finds the first instance that gives a value for the given method.
 
@@ -249,13 +249,14 @@ def firstmethod(method):
     def _matcher(it, *args, **kwargs):
         for obj in it:
             try:
-                answer = getattr(maybe_evaluate(obj), method)(*args, **kwargs)
+                meth = getattr(maybe_evaluate(obj), method)
+                reply = (on_call(meth, *args, **kwargs) if on_call
+                         else meth(*args, **kwargs))
             except AttributeError:
                 pass
             else:
-                if answer is not None:
-                    return answer
-
+                if reply is not None:
+                    return reply
     return _matcher
 
 
@@ -399,6 +400,9 @@ def head_from_fun(fun, bound=False, debug=False):
     return result
 
 
-def fun_takes_argument(name, fun):
+def fun_takes_argument(name, fun, position=None):
     spec = getfullargspec(fun)
-    return spec.keywords or spec.varargs or name in spec.args
+    return (
+        spec.keywords or spec.varargs or
+        (len(spec.args) >= position if position else name in spec.args)
+    )