Browse Source

Adds celery.task.current for the currently executing task

Ask Solem 13 years ago
parent
commit
7bc3975b7f
4 changed files with 24 additions and 9 deletions
  1. 16 6
      celery/app/__init__.py
  2. 4 0
      celery/execute/trace.py
  3. 4 1
      celery/task/__init__.py
  4. 0 2
      examples/resultgraph/tasks.py

+ 16 - 6
celery/app/__init__.py

@@ -21,11 +21,16 @@ from ..utils import cached_property, instantiate
 from . import annotations
 from . import base
 
-# Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
-# sets this, so it will always contain the last instantiated app,
-# and is the default app returned by :func:`app_or_default`.
-_tls = threading.local()
-_tls.current_app = None
+
+class _TLS(threading.local):
+    #: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
+    #: sets this, so it will always contain the last instantiated app,
+    #: and is the default app returned by :func:`app_or_default`.
+    current_app = None
+
+    #: The currently executing task.
+    current_task = None
+_tls = _TLS()
 
 
 class AppPickler(object):
@@ -228,13 +233,18 @@ default_loader = os.environ.get("CELERY_LOADER") or "default"
 
 #: Global fallback app instance.
 default_app = App("default", loader=default_loader,
-                  set_as_current=False, accept_magic_kwargs=True)
+                             set_as_current=False,
+                             accept_magic_kwargs=True)
 
 
 def current_app():
     return getattr(_tls, "current_app", None) or default_app
 
 
+def current_task():
+    return getattr(_tls, "current_task", None)
+
+
 def _app_or_default(app=None):
     """Returns the app provided or the default app if none.
 

+ 4 - 0
celery/execute/trace.py

@@ -25,6 +25,7 @@ import traceback
 
 from warnings import warn
 
+from .. import app as app_module
 from .. import current_app
 from .. import states, signals
 from ..datastructures import ExceptionInfo
@@ -121,6 +122,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     task_on_success = task.on_success
     task_after_return = task.after_return
     task_request = task.request
+    _tls = app_module._tls
 
     store_result = backend.store_result
     backend_cleanup = backend.process_cleanup
@@ -134,6 +136,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     def trace_task(uuid, args, kwargs, request=None):
         R = I = None
         try:
+            _tls.current_task = task
             update_request(request or {}, args=args,
                            called_directly=False, kwargs=kwargs)
             try:
@@ -181,6 +184,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 send_postrun(sender=task, task_id=uuid, task=task,
                             args=args, kwargs=kwargs, retval=retval)
             finally:
+                _tls.current_task = None
                 clear_request()
                 if not eager:
                     try:

+ 4 - 1
celery/task/__init__.py

@@ -11,13 +11,16 @@
 """
 from __future__ import absolute_import
 
-from ..app import app_or_default
+from ..app import app_or_default, current_task as _current_task
+from ..local import Proxy
 
 from .base import Task, PeriodicTask        # noqa
 from .sets import group, TaskSet, subtask   # noqa
 from .chords import chord                   # noqa
 from .control import discard_all            # noqa
 
+current = Proxy(_current_task)
+
 
 def task(*args, **kwargs):
     """Decorator to create a task class out of any callable.

+ 0 - 2
examples/resultgraph/tasks.py

@@ -97,5 +97,3 @@ class chord2(object):
     def __call__(self, body, **options):
         body.options.setdefault("task_id", uuid())
         unlock_graph.apply_async()
-
-