# async.py -- async I/O backend support utilities (see module docstring below).
  1. """Async I/O backend support utilities."""
  2. from __future__ import absolute_import, unicode_literals
  3. import socket
  4. import threading
  5. from collections import deque
  6. from time import sleep
  7. from weakref import WeakKeyDictionary
  8. from kombu.utils.compat import detect_environment
  9. from kombu.utils.objects import cached_property
  10. from celery import states
  11. from celery.exceptions import TimeoutError
  12. from celery.five import Empty, monotonic
  13. from celery.utils.threads import THREAD_TIMEOUT_MAX
  14. __all__ = (
  15. 'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
  16. 'register_drainer',
  17. )
  18. drainers = {}
  19. def register_drainer(name):
  20. """Decorator used to register a new result drainer type."""
  21. def _inner(cls):
  22. drainers[name] = cls
  23. return cls
  24. return _inner
  25. @register_drainer('default')
  26. class Drainer(object):
  27. """Result draining service."""
  28. def __init__(self, result_consumer):
  29. self.result_consumer = result_consumer
  30. def start(self):
  31. pass
  32. def stop(self):
  33. pass
  34. def drain_events_until(self, p, timeout=None, on_interval=None, wait=None):
  35. wait = wait or self.result_consumer.drain_events
  36. time_start = monotonic()
  37. while 1:
  38. # Total time spent may exceed a single call to wait()
  39. if timeout and monotonic() - time_start >= timeout:
  40. raise socket.timeout()
  41. try:
  42. yield self.wait_for(p, wait, timeout=1)
  43. except socket.timeout:
  44. pass
  45. if on_interval:
  46. on_interval()
  47. if p.ready: # got event on the wanted channel.
  48. break
  49. def wait_for(self, p, wait, timeout=None):
  50. wait(timeout=timeout)
  51. class greenletDrainer(Drainer):
  52. spawn = None
  53. _g = None
  54. def __init__(self, *args, **kwargs):
  55. super(greenletDrainer, self).__init__(*args, **kwargs)
  56. self._started = threading.Event()
  57. self._stopped = threading.Event()
  58. self._shutdown = threading.Event()
  59. def run(self):
  60. self._started.set()
  61. while not self._stopped.is_set():
  62. try:
  63. self.result_consumer.drain_events(timeout=1)
  64. except socket.timeout:
  65. pass
  66. self._shutdown.set()
  67. def start(self):
  68. if not self._started.is_set():
  69. self._g = self.spawn(self.run)
  70. self._started.wait()
  71. def stop(self):
  72. self._stopped.set()
  73. self._shutdown.wait(THREAD_TIMEOUT_MAX)
  74. def wait_for(self, p, wait, timeout=None):
  75. self.start()
  76. if not p.ready:
  77. sleep(0)
  78. @register_drainer('eventlet')
  79. class eventletDrainer(greenletDrainer):
  80. @cached_property
  81. def spawn(self):
  82. from eventlet import spawn
  83. return spawn
  84. @register_drainer('gevent')
  85. class geventDrainer(greenletDrainer):
  86. @cached_property
  87. def spawn(self):
  88. from gevent import spawn
  89. return spawn
  90. class AsyncBackendMixin(object):
  91. """Mixin for backends that enables the async API."""
  92. def _collect_into(self, result, bucket):
  93. self.result_consumer.buckets[result] = bucket
  94. def iter_native(self, result, no_ack=True, **kwargs):
  95. self._ensure_not_eager()
  96. results = result.results
  97. if not results:
  98. raise StopIteration()
  99. # we tell the result consumer to put consumed results
  100. # into these buckets.
  101. bucket = deque()
  102. for node in results:
  103. if node._cache:
  104. bucket.append(node)
  105. else:
  106. self._collect_into(node, bucket)
  107. for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
  108. while bucket:
  109. node = bucket.popleft()
  110. yield node.id, node._cache
  111. while bucket:
  112. node = bucket.popleft()
  113. yield node.id, node._cache
  114. def add_pending_result(self, result, weak=False, start_drainer=True):
  115. if start_drainer:
  116. self.result_consumer.drainer.start()
  117. try:
  118. self._maybe_resolve_from_buffer(result)
  119. except Empty:
  120. self._add_pending_result(result.id, result, weak=weak)
  121. return result
  122. def _maybe_resolve_from_buffer(self, result):
  123. result._maybe_set_cache(self._pending_messages.take(result.id))
  124. def _add_pending_result(self, task_id, result, weak=False):
  125. concrete, weak_ = self._pending_results
  126. if task_id not in weak_ and result.id not in concrete:
  127. (weak_ if weak else concrete)[task_id] = result
  128. self.result_consumer.consume_from(task_id)
  129. def add_pending_results(self, results, weak=False):
  130. self.result_consumer.drainer.start()
  131. return [self.add_pending_result(result, weak=weak, start_drainer=False)
  132. for result in results]
  133. def remove_pending_result(self, result):
  134. self._remove_pending_result(result.id)
  135. self.on_result_fulfilled(result)
  136. return result
  137. def _remove_pending_result(self, task_id):
  138. for map in self._pending_results:
  139. map.pop(task_id, None)
  140. def on_result_fulfilled(self, result):
  141. self.result_consumer.cancel_for(result.id)
  142. def wait_for_pending(self, result,
  143. callback=None, propagate=True, **kwargs):
  144. self._ensure_not_eager()
  145. for _ in self._wait_for_pending(result, **kwargs):
  146. pass
  147. return result.maybe_throw(callback=callback, propagate=propagate)
  148. def _wait_for_pending(self, result,
  149. timeout=None, on_interval=None, on_message=None,
  150. **kwargs):
  151. return self.result_consumer._wait_for_pending(
  152. result, timeout=timeout,
  153. on_interval=on_interval, on_message=on_message,
  154. )
  155. @property
  156. def is_async(self):
  157. return True
  158. class BaseResultConsumer(object):
  159. """Manager responsible for consuming result messages."""
  160. def __init__(self, backend, app, accept,
  161. pending_results, pending_messages):
  162. self.backend = backend
  163. self.app = app
  164. self.accept = accept
  165. self._pending_results = pending_results
  166. self._pending_messages = pending_messages
  167. self.on_message = None
  168. self.buckets = WeakKeyDictionary()
  169. self.drainer = drainers[detect_environment()](self)
  170. def start(self, initial_task_id, **kwargs):
  171. raise NotImplementedError()
  172. def stop(self):
  173. pass
  174. def drain_events(self, timeout=None):
  175. raise NotImplementedError()
  176. def consume_from(self, task_id):
  177. raise NotImplementedError()
  178. def cancel_for(self, task_id):
  179. raise NotImplementedError()
  180. def _after_fork(self):
  181. self.buckets.clear()
  182. self.buckets = WeakKeyDictionary()
  183. self.on_message = None
  184. self.on_after_fork()
  185. def on_after_fork(self):
  186. pass
  187. def drain_events_until(self, p, timeout=None, on_interval=None):
  188. return self.drainer.drain_events_until(
  189. p, timeout=timeout, on_interval=on_interval)
  190. def _wait_for_pending(self, result,
  191. timeout=None, on_interval=None, on_message=None,
  192. **kwargs):
  193. self.on_wait_for_pending(result, timeout=timeout, **kwargs)
  194. prev_on_m, self.on_message = self.on_message, on_message
  195. try:
  196. for _ in self.drain_events_until(
  197. result.on_ready, timeout=timeout,
  198. on_interval=on_interval):
  199. yield
  200. sleep(0)
  201. except socket.timeout:
  202. raise TimeoutError('The operation timed out.')
  203. finally:
  204. self.on_message = prev_on_m
  205. def on_wait_for_pending(self, result, timeout=None, **kwargs):
  206. pass
  207. def on_out_of_band_result(self, message):
  208. self.on_state_change(message.payload, message)
  209. def _get_pending_result(self, task_id):
  210. for mapping in self._pending_results:
  211. try:
  212. return mapping[task_id]
  213. except KeyError:
  214. pass
  215. raise KeyError(task_id)
  216. def on_state_change(self, meta, message):
  217. if self.on_message:
  218. self.on_message(meta)
  219. if meta['status'] in states.READY_STATES:
  220. task_id = meta['task_id']
  221. try:
  222. result = self._get_pending_result(task_id)
  223. except KeyError:
  224. # send to buffer in case we received this result
  225. # before it was added to _pending_results.
  226. self._pending_messages.put(task_id, meta)
  227. else:
  228. result._maybe_set_cache(meta)
  229. buckets = self.buckets
  230. try:
  231. # remove bucket for this result, since it's fulfilled
  232. bucket = buckets.pop(result)
  233. except KeyError:
  234. pass
  235. else:
  236. # send to waiter via bucket
  237. bucket.append(result)
  238. sleep(0)