Ver código fonte

Makes inspect.getargs(@task(x)) work. Issue #1833

Ask Solem 11 anos atrás
pai
commit
a0057bed68
3 arquivos alterados com 20 adições e 3 exclusões
  1. 4 1
      celery/app/base.py
  2. 1 1
      celery/datastructures.py
  3. 15 1
      celery/utils/__init__.py

+ 4 - 1
celery/app/base.py

@@ -33,6 +33,7 @@ from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured
 from celery.five import items, values
 from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
+from celery.utils import shadowsig
 from celery.utils.functional import first, maybe_list
 from celery.utils.imports import instantiate, symbol_by_name
 from celery.utils.objects import mro_lookup
@@ -235,7 +236,9 @@ class Celery(object):
             'run': fun if bind else staticmethod(fun),
             '_decorated': True,
             '__doc__': fun.__doc__,
-            '__module__': fun.__module__}, **options))()
+            '__module__': fun.__module__,
+            '__wrapped__': fun}, **options))()
+        shadowsig(T, fun)  # for inspect.getargspec
         task = self._tasks[T.name]  # return global instance.
         return task
 

+ 1 - 1
celery/datastructures.py

@@ -555,7 +555,7 @@ class LimitedSet(object):
     """Kind-of Set with limitations.
 
     Good for when you need to test for membership (`a in set`),
-    but the list might become to big.
+    but the list might become too big.
 
     :keyword maxlen: Maximum number of members before we start
                      evicting expired members.

+ 15 - 1
celery/utils/__init__.py

@@ -16,7 +16,7 @@ import warnings
 import datetime
 
 from functools import partial, wraps
-from inspect import getargspec
+from inspect import getargspec, ismethod
 from pprint import pprint
 
 from kombu.entity import Exchange, Queue
@@ -29,6 +29,8 @@ __all__ = ['worker_direct', 'warn_deprecated', 'deprecated', 'lpmerge',
            'jsonify', 'gen_task_name', 'nodename', 'nodesplit',
            'cached_property']
 
+PY3 = sys.version_info[0] == 3
+
 
 PENDING_DEPRECATION_FMT = """
     {description} is scheduled for deprecation in \
@@ -341,6 +343,18 @@ def default_nodename(hostname):
     name, host = nodesplit(hostname or '')
     return nodename(name or NODENAME_DEFAULT, host or socket.gethostname())
 
+
+def shadowsig(wrapper, wrapped):
+    if ismethod(wrapped):
+        wrapped = wrapped.__func__
+    wrapper.__code__ = wrapped.__code__
+    wrapper.__defaults__ = wrapper.func_defaults = wrapped.__defaults__
+
+    if not PY3:
+        wrapper.func_code = wrapper.__code__
+        wrapper.func_defaults = wrapper.__defaults__
+
+
 # ------------------------------------------------------------------------ #
 # > XXX Compat
 from .log import LOG_LEVELS     # noqa