فهرست منبع

Add check task-synchronous-subtasks for task_always_eager mode (#4322)

Denis Podlesniy 8 سال پیش
والد
کامیت
d02d260a19
3فایلهای تغییر یافته به همراه42 افزوده شده و 10 حذف شده
  1. 4 3
      celery/app/task.py
  2. 29 7
      celery/result.py
  3. 9 0
      t/unit/tasks/test_result.py

+ 4 - 3
celery/app/task.py

@@ -14,7 +14,7 @@ from celery.canvas import signature
 from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
 from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
 from celery.five import items, python_2_unicode_compatible
 from celery.five import items, python_2_unicode_compatible
 from celery.local import class_property
 from celery.local import class_property
-from celery.result import EagerResult
+from celery.result import EagerResult, denied_join_result
 from celery.utils import abstract
 from celery.utils import abstract
 from celery.utils.functional import mattrgetter, maybe_list
 from celery.utils.functional import mattrgetter, maybe_list
 from celery.utils.imports import instantiate
 from celery.utils.imports import instantiate
@@ -521,8 +521,9 @@ class Task(object):
 
 
         app = self._get_app()
         app = self._get_app()
         if app.conf.task_always_eager:
         if app.conf.task_always_eager:
-            return self.apply(args, kwargs, task_id=task_id or uuid(),
-                              link=link, link_error=link_error, **options)
+            with denied_join_result():
+                return self.apply(args, kwargs, task_id=task_id or uuid(),
+                                  link=link, link_error=link_error, **options)
         # add 'self' if this is a "task_method".
         # add 'self' if this is a "task_method".
         if self.__self__ is not None:
         if self.__self__ is not None:
             args = args if isinstance(args, tuple) else tuple(args or ())
             args = args if isinstance(args, tuple) else tuple(args or ())

+ 29 - 7
celery/result.py

@@ -51,6 +51,16 @@ def allow_join_result():
         _set_task_join_will_block(reset_value)
         _set_task_join_will_block(reset_value)
 
 
 
 
+@contextmanager
+def denied_join_result():
+    reset_value = task_join_will_block()
+    _set_task_join_will_block(True)
+    try:
+        yield
+    finally:
+        _set_task_join_will_block(reset_value)
+
+
 class ResultBase(object):
 class ResultBase(object):
     """Base class for results."""
     """Base class for results."""
 
 
@@ -617,7 +627,8 @@ class ResultSet(ResultBase):
                 raise TimeoutError('The operation timed out')
                 raise TimeoutError('The operation timed out')
 
 
     def get(self, timeout=None, propagate=True, interval=0.5,
     def get(self, timeout=None, propagate=True, interval=0.5,
-            callback=None, no_ack=True, on_message=None):
+            callback=None, no_ack=True, on_message=None,
+            disable_sync_subtasks=True):
         """See :meth:`join`.
         """See :meth:`join`.
 
 
         This is here for API compatibility with :class:`AsyncResult`,
         This is here for API compatibility with :class:`AsyncResult`,
@@ -629,11 +640,12 @@ class ResultSet(ResultBase):
         return (self.join_native if self.supports_native_join else self.join)(
         return (self.join_native if self.supports_native_join else self.join)(
             timeout=timeout, propagate=propagate,
             timeout=timeout, propagate=propagate,
             interval=interval, callback=callback, no_ack=no_ack,
             interval=interval, callback=callback, no_ack=no_ack,
-            on_message=on_message,
+            on_message=on_message, disable_sync_subtasks=disable_sync_subtasks
         )
         )
 
 
     def join(self, timeout=None, propagate=True, interval=0.5,
     def join(self, timeout=None, propagate=True, interval=0.5,
-             callback=None, no_ack=True, on_message=None, on_interval=None):
+             callback=None, no_ack=True, on_message=None,
+             disable_sync_subtasks=True, on_interval=None):
         """Gather the results of all tasks as a list in order.
         """Gather the results of all tasks as a list in order.
 
 
         Note:
         Note:
@@ -669,13 +681,17 @@ class ResultSet(ResultBase):
             no_ack (bool): Automatic message acknowledgment (Note that if this
             no_ack (bool): Automatic message acknowledgment (Note that if this
                 is set to :const:`False` then the messages
                 is set to :const:`False` then the messages
                 *will not be acknowledged*).
                 *will not be acknowledged*).
+            disable_sync_subtasks (bool): Disable tasks to wait for sub tasks
+                this is the default configuration. CAUTION do not enable this
+                unless you must.
 
 
         Raises:
         Raises:
             celery.exceptions.TimeoutError: if ``timeout`` isn't
             celery.exceptions.TimeoutError: if ``timeout`` isn't
                 :const:`None` and the operation takes longer than ``timeout``
                 :const:`None` and the operation takes longer than ``timeout``
                 seconds.
                 seconds.
         """
         """
-        assert_will_not_block()
+        if disable_sync_subtasks:
+            assert_will_not_block()
         time_start = monotonic()
         time_start = monotonic()
         remaining = None
         remaining = None
 
 
@@ -723,7 +739,8 @@ class ResultSet(ResultBase):
 
 
     def join_native(self, timeout=None, propagate=True,
     def join_native(self, timeout=None, propagate=True,
                     interval=0.5, callback=None, no_ack=True,
                     interval=0.5, callback=None, no_ack=True,
-                    on_message=None, on_interval=None):
+                    on_message=None, on_interval=None,
+                    disable_sync_subtasks=True):
         """Backend optimized version of :meth:`join`.
         """Backend optimized version of :meth:`join`.
 
 
         .. versionadded:: 2.2
         .. versionadded:: 2.2
@@ -734,7 +751,8 @@ class ResultSet(ResultBase):
         This is currently only supported by the amqp, Redis and cache
         This is currently only supported by the amqp, Redis and cache
         result backends.
         result backends.
         """
         """
-        assert_will_not_block()
+        if disable_sync_subtasks:
+            assert_will_not_block()
         order_index = None if callback else {
         order_index = None if callback else {
             result.id: i for i, result in enumerate(self.results)
             result.id: i for i, result in enumerate(self.results)
         }
         }
@@ -916,7 +934,11 @@ class EagerResult(AsyncResult):
     def ready(self):
     def ready(self):
         return True
         return True
 
 
-    def get(self, timeout=None, propagate=True, **kwargs):
+    def get(self, timeout=None, propagate=True,
+            disable_sync_subtasks=True, **kwargs):
+        if disable_sync_subtasks:
+            assert_will_not_block()
+
         if self.successful():
         if self.successful():
             return self.result
             return self.result
         elif self.state in states.PROPAGATE_STATES:
         elif self.state in states.PROPAGATE_STATES:

+ 9 - 0
t/unit/tasks/test_result.py

@@ -872,6 +872,15 @@ class test_EagerResult:
         res = self.raising.apply(args=[3, 3])
         res = self.raising.apply(args=[3, 3])
         assert not res.revoke()
         assert not res.revoke()
 
 
+    @patch('celery.result.task_join_will_block')
+    def test_get_sync_subtask_option(self, task_join_will_block):
+        task_join_will_block.return_value = True
+        tid = uuid()
+        res_subtask_async = EagerResult(tid, 'x', 'x', states.SUCCESS)
+        with pytest.raises(RuntimeError):
+            res_subtask_async.get()
+        res_subtask_async.get(disable_sync_subtasks=False)
+
 
 
 class test_tuples:
 class test_tuples: