
Raise RuntimeError if a task calls result.get() and it would block

Ask Solem 11 years ago
parent
revision
4a9c9705d1
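
For context, a minimal sketch of the anti-pattern this commit rejects. The app name, broker URL and task names below are illustrative only and not part of the change:

from celery import Celery

app = Celery('example', broker='amqp://')  # illustrative broker URL

@app.task
def add(x, y):
    return x + y

@app.task
def bad_caller():
    # Inside a prefork worker this .get() waits on the same pool that is
    # running bad_caller(), risking a deadlock; after this change it
    # raises RuntimeError instead of silently blocking.
    return add.delay(2, 2).get()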

+ 11 - 0
celery/_state.py

@@ -28,6 +28,17 @@ default_app = None
 #: List of all app instances (weakrefs), must not be used directly.
 _apps = set()
 
+_task_join_will_block = False
+
+
+def _set_task_join_will_block(blocks):
+    global _task_join_will_block
+    _task_join_will_block = blocks
+
+
+def task_join_will_block():
+    return _task_join_will_block
+
 
 class _TLS(threading.local):
     #: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
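
A quick sketch of how the two new helpers fit together, assuming the corrected assignment above (calling _set_task_join_will_block() directly is only for illustration; the worker sets it during startup and in process_initializer):

from celery._state import _set_task_join_will_block, task_join_will_block

_set_task_join_will_block(True)        # prefork worker: joining would block
assert task_join_will_block() is True

_set_task_join_will_block(False)       # eventlet/gevent pool: safe to join
assert task_join_will_block() is False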

+ 2 - 0
celery/concurrency/base.py

@@ -66,6 +66,8 @@ class BasePool(object):
     #: only used by multiprocessing pool
     uses_semaphore = False
 
+    task_join_will_block = True
+
     def __init__(self, limit=None, putlocks=True,
                  forking_enable=True, callbacks_propagate=(), **options):
         self.limit = limit

+ 1 - 0
celery/concurrency/eventlet.py

@@ -110,6 +110,7 @@ class TaskPool(base.BasePool):
 
     signal_safe = False
     is_green = True
+    task_join_will_block = False
 
     def __init__(self, *args, **kwargs):
         from eventlet import greenthread

+ 1 - 0
celery/concurrency/gevent.py

@@ -96,6 +96,7 @@ class TaskPool(BasePool):
 
     signal_safe = False
     is_green = True
+    task_join_will_block = False
 
     def __init__(self, *args, **kwargs):
         from gevent import spawn_raw

+ 2 - 1
celery/concurrency/prefork.py

@@ -16,7 +16,7 @@ from billiard.pool import RUN, CLOSE, Pool as BlockingPool
 
 from celery import platforms
 from celery import signals
-from celery._state import set_default_app
+from celery._state import set_default_app, _set_task_join_will_block
 from celery.app import trace
 from celery.concurrency.base import BasePool
 from celery.five import items
@@ -53,6 +53,7 @@ def process_initializer(app, hostname):
     logging works.
 
     """
+    _set_task_join_will_block(True)
     platforms.signals.reset(*WORKER_SIGRESET)
     platforms.signals.ignore(*WORKER_SIGIGNORE)
     platforms.set_mp_process_title('celeryd', hostname=hostname)

+ 15 - 0
celery/result.py

@@ -18,6 +18,7 @@ from kombu.utils.compat import OrderedDict
 
 from . import current_app
 from . import states
+from ._state import task_join_will_block
 from .app import app_or_default
 from .datastructures import DependencyGraph, GraphFormatter
 from .exceptions import IncompleteStream, TimeoutError
@@ -26,6 +27,17 @@ from .five import items, range, string_t, monotonic
 __all__ = ['ResultBase', 'AsyncResult', 'ResultSet', 'GroupResult',
            'EagerResult', 'result_from_tuple']
 
+E_WOULDBLOCK = """\
+Never call result.get() within a task!
+See http://docs.celeryq.org/en/latest/userguide/tasks.html\
+#task-synchronous-subtasks
+"""
+
+
+def assert_will_not_block():
+    if task_join_will_block():
+        raise RuntimeError(E_WOULDBLOCK)
+
 
 class ResultBase(object):
     """Base class for all results"""
@@ -114,6 +126,7 @@ class AsyncResult(ResultBase):
         be re-raised.
 
         """
+        assert_will_not_block()
         if propagate and self.parent:
             for node in reversed(list(self._parents())):
                 node.get(propagate=True, timeout=timeout, interval=interval)
@@ -519,6 +532,7 @@ class ResultSet(ResultBase):
             seconds.
 
         """
+        assert_will_not_block()
         time_start = monotonic()
         remaining = None
 
@@ -570,6 +584,7 @@ class ResultSet(ResultBase):
         result backends.
 
         """
+        assert_will_not_block()
         order_index = None if callback else dict(
             (result.id, i) for i, result in enumerate(self.results)
         )
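
Rather than calling .get() inside a task, the intended pattern is to link tasks asynchronously. A hedged sketch, assuming add and mul are existing @app.task functions:

from celery import chain

# Pass the result of add(2, 2) on to mul(?, 10) without any task ever
# waiting on another task's result inside the worker.
result = chain(add.s(2, 2), mul.s(10)).apply_async()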

+ 2 - 0
celery/worker/components.py

@@ -16,6 +16,7 @@ from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore
 from kombu.async.timer import Timer as _Timer
 
 from celery import bootsteps
+from celery._state import _set_task_join_will_block
 from celery.exceptions import ImproperlyConfigured
 from celery.five import string_t
 from celery.utils.log import worker_logger as logger
@@ -174,6 +175,7 @@ class Pool(bootsteps.StartStopStep):
             semaphore=semaphore,
             sched_strategy=self.optimization,
         )
+        _set_task_join_will_block(pool.task_join_will_block)
         return pool
 
     def info(self, w):