Browse Source

Adds app.current_worker_task property. Closes #2100

Ask Solem 9 years ago
parent
commit
a4a5c2a794
4 changed files with 17 additions and 13 deletions
  1. 7 3
      celery/app/base.py
  2. 2 2
      celery/app/builtins.py
  3. 2 2
      celery/canvas.py
  4. 6 6
      celery/tests/app/test_builtins.py

+ 7 - 3
celery/app/base.py

@@ -635,12 +635,12 @@ class Celery(object):
         options = router.route(options, route_name or name, args, kwargs)
 
         if root_id is None:
-            parent, have_parent = get_current_worker_task(), True
+            parent, have_parent = self.current_worker_task, True
             if parent:
                 root_id = parent.request.root_id or parent.request.id
         if parent_id is None:
             if not have_parent:
-                parent, have_parent = get_current_worker_task(), True
+                parent, have_parent = self.current_worker_task, True
             if parent:
                 parent_id = parent.request.id
 
@@ -661,7 +661,7 @@ class Celery(object):
         result = (result_cls or self.AsyncResult)(task_id)
         if add_to_parent:
             if not have_parent:
-                parent, have_parent = get_current_worker_task(), True
+                parent, have_parent = self.current_worker_task, True
             if parent:
                 parent.add_trail(result)
         return result
@@ -1025,6 +1025,10 @@ class Celery(object):
         :const:`None`."""
         return _task_stack.top
 
+    @property
+    def current_worker_task(self):
+        return get_current_worker_task()
+
     @cached_property
     def oid(self):
         return oid_from(self)

+ 2 - 2
celery/app/builtins.py

@@ -9,7 +9,7 @@
 """
 from __future__ import absolute_import
 
-from celery._state import get_current_worker_task, connect_on_app_finalize
+from celery._state import connect_on_app_finalize
 from celery.utils.log import get_logger
 
 __all__ = []
@@ -157,7 +157,7 @@ def add_group_task(app):
         with app.producer_or_acquire() as producer:
             [stask.apply_async(group_id=group_id, producer=producer,
                                add_to_parent=False) for stask in taskit]
-        parent = get_current_worker_task()
+        parent = app.current_worker_task
         if add_to_parent and parent:
             parent.add_trail(result)
         return result

+ 2 - 2
celery/canvas.py

@@ -22,7 +22,7 @@ from itertools import chain as _chain
 
 from kombu.utils import cached_property, fxrange, reprcall, uuid
 
-from celery._state import current_app, get_current_worker_task
+from celery._state import current_app
 from celery.local import try_import
 from celery.result import GroupResult
 from celery.utils import abstract
@@ -761,7 +761,7 @@ class group(Signature):
         if len(result) == 1 and isinstance(result[0], GroupResult):
             result = result[0]
 
-        parent_task = get_current_worker_task()
+        parent_task = app.current_worker_task
         if add_to_parent and parent_task:
             parent_task.add_trail(result)
         return result

+ 6 - 6
celery/tests/app/test_builtins.py

@@ -111,21 +111,21 @@ class test_group(BuiltinsCase):
             task.clone.attach_mock(Mock(), 'apply_async')
         return g, result
 
-    @patch('celery.app.builtins.get_current_worker_task')
-    def test_task(self, get_current_worker_task):
+    @patch('celery.app.base.Celery.current_worker_task')
+    def test_task(self, current_worker_task):
         g, result = self.mock_group(self.add.s(2), self.add.s(4))
         self.task(g.tasks, result, result.id, (2,)).results
         g.tasks[0].clone().apply_async.assert_called_with(
             group_id=result.id, producer=self.app.producer_or_acquire(),
             add_to_parent=False,
         )
-        get_current_worker_task().add_trail.assert_called_with(result)
+        current_worker_task.add_trail.assert_called_with(result)
 
-    @patch('celery.app.builtins.get_current_worker_task')
-    def test_task__disable_add_to_parent(self, get_current_worker_task):
+    @patch('celery.app.base.Celery.current_worker_task')
+    def test_task__disable_add_to_parent(self, current_worker_task):
         g, result = self.mock_group(self.add.s(2, 2), self.add.s(4, 4))
         self.task(g.tasks, result, result.id, None, add_to_parent=False)
-        self.assertFalse(get_current_worker_task().add_trail.called)
+        self.assertFalse(current_worker_task.add_trail.called)
 
 
 class test_chain(BuiltinsCase):