Prechádzať zdrojové kódy

Optimize __call__ at execution not class creation.

Moves __call__ optimization so that it happens when the task is executed,
rather than when the class is created.

This may fix a bug where custom __call__'s mysteriously disappears.
Ask Solem 13 rokov pred
rodič
commit
93be572862
2 zmenil súbory, kde vykonal 34 pridanie a 22 odobranie
  1. 6 21
      celery/app/task/__init__.py
  2. 28 1
      celery/execute/trace.py

+ 6 - 21
celery/app/task/__init__.py

@@ -19,7 +19,6 @@ from ... import current_app
 from ... import states
 from ...datastructures import ExceptionInfo
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
-from ...execute.trace import eager_trace_task
 from ...result import EagerResult
 from ...utils import (fun_takes_kwargs, instantiate,
                       mattrgetter, uuid, maybe_reraise)
@@ -90,10 +89,6 @@ class TaskType(type):
         new = super(TaskType, cls).__new__
         task_module = attrs.get("__module__") or "__main__"
 
-        if "__call__" in attrs:
-            # see note about __call__ below.
-            attrs["__defines_call__"] = True
-
         # In old Celery the @task decorator didn't exist, so one would create
         # classes instead and use them directly (e.g. MyTask.apply_async()).
         # the use of classmethods was a hack so that it was not necessary
@@ -129,22 +124,6 @@ class TaskType(type):
             attrs["name"] = '.'.join([module_name, name])
             autoname = True
 
-        # - Automatically generate __call__.
-        # If this or none of its bases define __call__, we simply
-        # alias it to the ``run`` method, as
-        # this means we can skip a stacktrace frame :)
-        if not (attrs.get("__call__")
-                or any(getattr(b, "__defines_call__", False) for b in bases)):
-            try:
-                attrs["__call__"] = attrs["run"]
-            except KeyError:
-
-                # the class does not yet define run,
-                # so we can't optimize this case.
-                def __call__(self, *args, **kwargs):
-                    return self.run(*args, **kwargs)
-                attrs["__call__"] = __call__
-
         # - Create and register class.
         # Because of the way import happens (recursively)
         # we may or may not be the first time the task tries to register
@@ -368,6 +347,9 @@ class BaseTask(object):
     def __reduce__(self):
         return (_unpickle_task, (self.name, ), None)
 
+    def __call__(self, *args, **kwargs):
+        return self.run(*args, **kwargs)
+
     def run(self, *args, **kwargs):
         """The body of the task executed by workers."""
         raise NotImplementedError("Tasks must define the run method.")
@@ -665,6 +647,9 @@ class BaseTask(object):
         :rtype :class:`celery.result.EagerResult`:
 
         """
+        # trace imports BaseTask, so need to import inline.
+        from ...execute.trace import eager_trace_task
+
         app = self._get_app()
         args = args or []
         kwargs = kwargs or {}

+ 28 - 1
celery/execute/trace.py

@@ -28,6 +28,7 @@ from warnings import warn
 from .. import app as app_module
 from .. import current_app
 from .. import states, signals
+from ..app.task import BaseTask
 from ..datastructures import ExceptionInfo
 from ..exceptions import RetryTaskError
 from ..utils.serialization import get_pickleable_exception
@@ -43,6 +44,27 @@ FAILURE = states.FAILURE
 EXCEPTION_STATES = states.EXCEPTION_STATES
 
 
+def mro_lookup(cls, attr, stop=()):
+    """Returns the first node by MRO order that defines an attribute.
+
+    :keyword stop: A list of types that if reached will stop the search.
+
+    :returns None: if the attribute was not found.
+
+    """
+    for node in cls.mro():
+        if node in stop:
+            return
+        if attr in node.__dict__:
+            return node
+
+
+def defines_custom_call(task):
+    """Returns true if the task or one of its bases
+    defines __call__ (excluding the one in BaseTask)."""
+    return mro_lookup(task.__class__, "__call__", stop=(BaseTask, object))
+
+
 class TraceInfo(object):
     __slots__ = ("state", "retval", "exc_info",
                  "exc_type", "exc_value", "tb", "strtb")
@@ -106,6 +128,11 @@ class TraceInfo(object):
 
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         Info=TraceInfo, eager=False, propagate=False):
+    # If the task doesn't define a custom __call__ method
+    # we optimize it away by simply calling the run method directly,
+    # saving the extra method call and a line less in the stack trace.
+    fun = task if defines_custom_call(task) else task.run
+
     task = task or current_app.tasks[name]
     loader = loader or current_app.loader
     backend = task.backend
@@ -149,7 +176,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
 
                 # -*- TRACE -*-
                 try:
-                    R = retval = task(*args, **kwargs)
+                    R = retval = fun(*args, **kwargs)
                     state, einfo = SUCCESS, None
                 except RetryTaskError, exc:
                     I = Info(RETRY, exc, sys.exc_info())