Browse Source

Now checks task arguments when calling a task (.delay, .apply_async)

Ask Solem 10 years ago
parent
commit
3e91891c7b
3 changed files with 55 additions and 4 deletions
  1. 2 1
      celery/app/base.py
  2. 7 0
      celery/app/task.py
  3. 46 3
      celery/utils/functional.py

+ 2 - 1
celery/app/base.py

@@ -38,7 +38,7 @@ from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
 from celery.utils import gen_task_name
 from celery.utils.dispatch import Signal
-from celery.utils.functional import first, maybe_list
+from celery.utils.functional import first, maybe_list, head_from_fun
 from celery.utils.imports import instantiate, symbol_by_name
 from celery.utils.objects import FallbackContext, mro_lookup
 
@@ -286,6 +286,7 @@ class Celery(object):
                 '_decorated': True,
                 '__doc__': fun.__doc__,
                 '__module__': fun.__module__,
+                '__header__': staticmethod(head_from_fun(fun)),
                 '__wrapped__': fun}, **options))()
             self._tasks[task.name] = task
             task.bind(self)  # connects task to this app

+ 7 - 0
celery/app/task.py

@@ -448,6 +448,13 @@ class Task(object):
             be replaced by a local :func:`apply` call instead.
 
         """
+        try:
+            check_arguments = self.__header__
+        except AttributeError:
+            pass
+        else:
+            check_arguments(*args or (), **kwargs or {})
+
         app = self._get_app()
         if app.conf.CELERY_ALWAYS_EAGER:
             return self.apply(args, kwargs, task_id=task_id or uuid(),

+ 46 - 3
celery/utils/functional.py

@@ -6,13 +6,14 @@
     Utilities for functions.
 
 """
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function
 
 import sys
 import threading
 
 from collections import OrderedDict
-from functools import wraps
+from functools import partial, wraps
+from inspect import getargspec, isfunction, ismethod
 from itertools import islice
 
 from kombu.utils import cached_property
@@ -22,10 +23,15 @@ from celery.five import UserDict, UserList, items, keys
 
 __all__ = ['LRUCache', 'is_list', 'maybe_list', 'memoize', 'mlazy', 'noop',
            'first', 'firstmethod', 'chunks', 'padlist', 'mattrgetter', 'uniq',
-           'regen', 'dictfilter', 'lazy', 'maybe_evaluate']
+           'regen', 'dictfilter', 'lazy', 'maybe_evaluate', 'head_from_fun']
 
 KEYWORD_MARK = object()
 
+FUNHEAD_TEMPLATE = """
+def {fun_name}({fun_args}):
+    return {fun_value}
+"""
+
 
 class LRUCache(UserDict):
     """LRU Cache implementation using a doubly linked list to track access.
@@ -302,3 +308,40 @@ def dictfilter(d=None, **kw):
     """Remove all keys from dict ``d`` whose value is :const:`None`"""
     d = kw if d is None else (dict(d, **kw) if kw else d)
     return {k: v for k, v in items(d) if v is not None}
+
+
+def _argsfromspec(spec, replace_defaults=True):
+    if spec.defaults:
+        split = len(spec.defaults)
+        defaults = (list(range(len(spec.defaults))) if replace_defaults
+                    else spec.defaults)
+        positional = spec.args[:-split]
+        optional = list(zip(spec.args[-split:], defaults))
+    else:
+        positional, optional = spec.args, []
+    return ', '.join(filter(None, [
+        ', '.join(positional),
+        ', '.join('{0}={1}'.format(k, v) for k, v in optional),
+        '*{0}'.format(spec.varargs) if spec.varargs else None,
+        '**{0}'.format(spec.keywords) if spec.keywords else None,
+    ]))
+
+
+def head_from_fun(fun, debug=True):
+    if not isfunction(fun) and hasattr(fun, '__call__'):
+        name, fun = fun.__class__.__name__, fun.__call__
+    else:
+        name = fun.__name__
+    spec = getargspec(fun)
+    definition = FUNHEAD_TEMPLATE.format(
+        fun_name=name,
+        fun_args=_argsfromspec(getargspec(fun)),
+        fun_value=1,
+    )
+    if debug:
+        print(definition, file=sys.stderr)
+    namespace = {'__name__': 'headof_{0}'.format(name)}
+    exec(definition, namespace)
+    result = namespace[name]
+    result._source = definition
+    return result