async.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """Async I/O backend support utilities."""
  2. from __future__ import absolute_import, unicode_literals
  3. import socket
  4. import threading
  5. from collections import deque
  6. from time import sleep
  7. from weakref import WeakKeyDictionary
  8. from kombu.utils.compat import detect_environment
  9. from kombu.utils.objects import cached_property
  10. from celery import states
  11. from celery.exceptions import TimeoutError
  12. from celery.five import Empty, monotonic
  13. from celery.utils.threads import THREAD_TIMEOUT_MAX
  14. __all__ = [
  15. 'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
  16. 'register_drainer',
  17. ]
  18. drainers = {}
  19. def register_drainer(name):
  20. """Decorator used to register a new result drainer type."""
  21. def _inner(cls):
  22. drainers[name] = cls
  23. return cls
  24. return _inner
  25. @register_drainer('default')
  26. class Drainer(object):
  27. """Result draining service."""
  28. def __init__(self, result_consumer):
  29. self.result_consumer = result_consumer
  30. def start(self):
  31. pass
  32. def stop(self):
  33. pass
  34. def drain_events_until(self, p, timeout=None, on_interval=None,
  35. monotonic=monotonic, wait=None):
  36. wait = wait or self.result_consumer.drain_events
  37. time_start = monotonic()
  38. while 1:
  39. # Total time spent may exceed a single call to wait()
  40. if timeout and monotonic() - time_start >= timeout:
  41. raise socket.timeout()
  42. try:
  43. yield self.wait_for(p, wait, timeout=1)
  44. except socket.timeout:
  45. pass
  46. if on_interval:
  47. on_interval()
  48. if p.ready: # got event on the wanted channel.
  49. break
  50. def wait_for(self, p, wait, timeout=None):
  51. wait(timeout=timeout)
class greenletDrainer(Drainer):
    """Drainer that runs the drain loop in a green thread.

    Subclasses must provide :attr:`spawn`, the spawn function of their
    greenlet framework (see ``eventletDrainer``/``geventDrainer``).
    """

    #: Greenlet spawn function (provided by subclass).
    spawn = None

    #: The spawned drain greenlet (set by :meth:`start`).
    _g = None

    def __init__(self, *args, **kwargs):
        super(greenletDrainer, self).__init__(*args, **kwargs)
        # Lifecycle flags shared between caller and the drain loop:
        self._started = threading.Event()   # set once run() is executing
        self._stopped = threading.Event()   # request for the loop to exit
        self._shutdown = threading.Event()  # set when the loop has exited

    def run(self):
        # Body of the drain greenlet: drain events repeatedly until
        # stop() is requested.  Each drain is capped at 1s so the stop
        # flag is re-checked at least that often.
        self._started.set()
        while not self._stopped.is_set():
            try:
                self.result_consumer.drain_events(timeout=1)
            except socket.timeout:
                pass
        self._shutdown.set()

    def start(self):
        # Spawn the drain greenlet once, then block until it's running.
        if not self._started.is_set():
            self._g = self.spawn(self.run)
            self._started.wait()

    def stop(self):
        self._stopped.set()
        # Wait (bounded by THREAD_TIMEOUT_MAX) for run() to acknowledge.
        self._shutdown.wait(THREAD_TIMEOUT_MAX)

    def wait_for(self, p, wait, timeout=None):
        self.start()
        if not p.ready:
            # Yield control so the drain greenlet can make progress.
            sleep(0)
  79. @register_drainer('eventlet')
  80. class eventletDrainer(greenletDrainer):
  81. @cached_property
  82. def spawn(self):
  83. from eventlet import spawn
  84. return spawn
  85. @register_drainer('gevent')
  86. class geventDrainer(greenletDrainer):
  87. @cached_property
  88. def spawn(self):
  89. from gevent import spawn
  90. return spawn
  91. class AsyncBackendMixin(object):
  92. """Mixin for backends that enables the async API."""
  93. def _collect_into(self, result, bucket):
  94. self.result_consumer.buckets[result] = bucket
  95. def iter_native(self, result, no_ack=True, **kwargs):
  96. self._ensure_not_eager()
  97. results = result.results
  98. if not results:
  99. raise StopIteration()
  100. # we tell the result consumer to put consumed results
  101. # into these buckets.
  102. bucket = deque()
  103. for node in results:
  104. if node._cache:
  105. bucket.append(node)
  106. else:
  107. self._collect_into(node, bucket)
  108. for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
  109. while bucket:
  110. node = bucket.popleft()
  111. yield node.id, node._cache
  112. while bucket:
  113. node = bucket.popleft()
  114. yield node.id, node._cache
  115. def add_pending_result(self, result, weak=False, start_drainer=True):
  116. if start_drainer:
  117. self.result_consumer.drainer.start()
  118. try:
  119. self._maybe_resolve_from_buffer(result)
  120. except Empty:
  121. self._add_pending_result(result.id, result, weak=weak)
  122. return result
  123. def _maybe_resolve_from_buffer(self, result):
  124. result._maybe_set_cache(self._pending_messages.take(result.id))
  125. def _add_pending_result(self, task_id, result, weak=False):
  126. weak, concrete = self._pending_results
  127. if task_id not in weak and result.id not in concrete:
  128. (weak if weak else concrete)[task_id] = result
  129. self.result_consumer.consume_from(task_id)
  130. def add_pending_results(self, results, weak=False):
  131. self.result_consumer.drainer.start()
  132. return [self.add_pending_result(result, weak=weak, start_drainer=False)
  133. for result in results]
  134. def remove_pending_result(self, result):
  135. self._remove_pending_result(result.id)
  136. self.on_result_fulfilled(result)
  137. return result
  138. def _remove_pending_result(self, task_id):
  139. for map in self._pending_results:
  140. map.pop(task_id, None)
  141. def on_result_fulfilled(self, result):
  142. self.result_consumer.cancel_for(result.id)
  143. def wait_for_pending(self, result,
  144. callback=None, propagate=True, **kwargs):
  145. self._ensure_not_eager()
  146. for _ in self._wait_for_pending(result, **kwargs):
  147. pass
  148. return result.maybe_throw(callback=callback, propagate=propagate)
  149. def _wait_for_pending(self, result,
  150. timeout=None, on_interval=None, on_message=None,
  151. **kwargs):
  152. return self.result_consumer._wait_for_pending(
  153. result, timeout=timeout,
  154. on_interval=on_interval, on_message=on_message,
  155. )
  156. @property
  157. def is_async(self):
  158. return True
  159. class BaseResultConsumer(object):
  160. """Manager responsible for consuming result messages."""
  161. def __init__(self, backend, app, accept,
  162. pending_results, pending_messages):
  163. self.backend = backend
  164. self.app = app
  165. self.accept = accept
  166. self._pending_results = pending_results
  167. self._pending_messages = pending_messages
  168. self.on_message = None
  169. self.buckets = WeakKeyDictionary()
  170. self.drainer = drainers[detect_environment()](self)
  171. def start(self):
  172. raise NotImplementedError()
  173. def stop(self):
  174. pass
  175. def drain_events(self, timeout=None):
  176. raise NotImplementedError()
  177. def consume_from(self, task_id):
  178. raise NotImplementedError()
  179. def cancel_for(self, task_id):
  180. raise NotImplementedError()
  181. def _after_fork(self):
  182. self.buckets.clear()
  183. self.buckets = WeakKeyDictionary()
  184. self.on_message = None
  185. self.on_after_fork()
  186. def on_after_fork(self):
  187. pass
  188. def drain_events_until(self, p, timeout=None, on_interval=None):
  189. return self.drainer.drain_events_until(
  190. p, timeout=timeout, on_interval=on_interval)
  191. def _wait_for_pending(self, result,
  192. timeout=None, on_interval=None, on_message=None,
  193. **kwargs):
  194. self.on_wait_for_pending(result, timeout=timeout, **kwargs)
  195. prev_on_m, self.on_message = self.on_message, on_message
  196. try:
  197. for _ in self.drain_events_until(
  198. result.on_ready, timeout=timeout,
  199. on_interval=on_interval):
  200. yield
  201. sleep(0)
  202. except socket.timeout:
  203. raise TimeoutError('The operation timed out.')
  204. finally:
  205. self.on_message = prev_on_m
  206. def on_wait_for_pending(self, result, timeout=None, **kwargs):
  207. pass
  208. def on_out_of_band_result(self, message):
  209. self.on_state_change(message.payload, message)
  210. def _get_pending_result(self, task_id):
  211. for mapping in self._pending_results:
  212. try:
  213. return mapping[task_id]
  214. except KeyError:
  215. pass
  216. raise KeyError(task_id)
  217. def on_state_change(self, meta, message):
  218. if self.on_message:
  219. self.on_message(meta)
  220. if meta['status'] in states.READY_STATES:
  221. task_id = meta['task_id']
  222. try:
  223. result = self._get_pending_result(task_id)
  224. except KeyError:
  225. # send to buffer in case we received this result
  226. # before it was added to _pending_results.
  227. self._pending_messages.put(task_id, meta)
  228. else:
  229. result._maybe_set_cache(meta)
  230. buckets = self.buckets
  231. try:
  232. # remove bucket for this result, since it's fulfilled
  233. bucket = buckets.pop(result)
  234. except KeyError:
  235. pass
  236. else:
  237. # send to waiter via bucket
  238. bucket.append(result)
  239. sleep(0)