# async.py — (original file header; the raw listing residue was removed)
  1. """Async I/O backend support utilities."""
  2. import socket
  3. import threading
  4. from collections import deque
  5. from time import monotonic, sleep
  6. from weakref import WeakKeyDictionary
  7. from queue import Empty
  8. from kombu.utils.compat import detect_environment
  9. from kombu.utils.objects import cached_property
  10. from celery import states
  11. from celery.exceptions import TimeoutError
# Public API of this module.
__all__ = [
    'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
    'register_drainer',
]
  16. drainers = {}
  17. def register_drainer(name):
  18. """Decorator used to register a new result drainer type."""
  19. def _inner(cls):
  20. drainers[name] = cls
  21. return cls
  22. return _inner
@register_drainer('default')
class Drainer:
    """Result draining service.

    Default blocking drainer used when no green-thread environment is
    detected: it polls for result events in the calling thread.
    """

    def __init__(self, result_consumer):
        # Consumer providing ``drain_events`` (see BaseResultConsumer).
        self.result_consumer = result_consumer

    def start(self):
        # No background machinery needed for the blocking drainer.
        pass

    def stop(self):
        pass

    def drain_events_until(self, p, timeout=None, on_interval=None, wait=None):
        """Generator draining events until promise ``p`` is fulfilled.

        Arguments:
            p: object with a truthy ``ready`` attribute once fulfilled
                (presumably a promise — confirm with caller).
            timeout (float): overall limit in seconds; raises
                :exc:`socket.timeout` when exceeded.  ``None``/``0``
                disables the limit.
            on_interval (Callable): called after every wait slice.
            wait (Callable): alternative wait function; defaults to
                ``result_consumer.drain_events``.
        """
        wait = wait or self.result_consumer.drain_events
        time_start = monotonic()

        while 1:
            # Total time spent may exceed a single call to wait()
            if timeout and monotonic() - time_start >= timeout:
                raise socket.timeout()
            try:
                # Wait in short (1s) slices so that the timeout and
                # on_interval callback are checked regularly.
                yield self.wait_for(p, wait, timeout=1)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break

    def wait_for(self, p, wait, timeout=None):
        # Single blocking wait slice; greenlet subclasses override this
        # to yield control to the hub instead.
        wait(timeout=timeout)
  49. class greenletDrainer(Drainer):
  50. spawn = None
  51. _g = None
  52. def __init__(self, *args, **kwargs):
  53. super(greenletDrainer, self).__init__(*args, **kwargs)
  54. self._started = threading.Event()
  55. self._stopped = threading.Event()
  56. self._shutdown = threading.Event()
  57. def run(self):
  58. self._started.set()
  59. while not self._stopped.is_set():
  60. try:
  61. self.result_consumer.drain_events(timeout=1)
  62. except socket.timeout:
  63. pass
  64. self._shutdown.set()
  65. def start(self):
  66. if not self._started.is_set():
  67. self._g = self.spawn(self.run)
  68. self._started.wait()
  69. def stop(self):
  70. self._stopped.set()
  71. self._shutdown.wait(threading.TIMEOUT_MAX)
  72. def wait_for(self, p, wait, timeout=None):
  73. self.start()
  74. if not p.ready:
  75. sleep(0)
  76. @register_drainer('eventlet')
  77. class eventletDrainer(greenletDrainer):
  78. @cached_property
  79. def spawn(self):
  80. from eventlet import spawn
  81. return spawn
  82. @register_drainer('gevent')
  83. class geventDrainer(greenletDrainer):
  84. @cached_property
  85. def spawn(self):
  86. from gevent import spawn
  87. return spawn
  88. class AsyncBackendMixin:
  89. """Mixin for backends that enables the async API."""
  90. def _collect_into(self, result, bucket):
  91. self.result_consumer.buckets[result] = bucket
  92. def iter_native(self, result, no_ack=True, **kwargs):
  93. self._ensure_not_eager()
  94. results = result.results
  95. if not results:
  96. raise StopIteration()
  97. # we tell the result consumer to put consumed results
  98. # into these buckets.
  99. bucket = deque()
  100. for node in results:
  101. if node._cache:
  102. bucket.append(node)
  103. else:
  104. self._collect_into(node, bucket)
  105. for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
  106. while bucket:
  107. node = bucket.popleft()
  108. yield node.id, node._cache
  109. while bucket:
  110. node = bucket.popleft()
  111. yield node.id, node._cache
  112. def add_pending_result(self, result, weak=False, start_drainer=True):
  113. if start_drainer:
  114. self.result_consumer.drainer.start()
  115. try:
  116. self._maybe_resolve_from_buffer(result)
  117. except Empty:
  118. self._add_pending_result(result.id, result, weak=weak)
  119. return result
  120. def _maybe_resolve_from_buffer(self, result):
  121. result._maybe_set_cache(self._pending_messages.take(result.id))
  122. def _add_pending_result(self, task_id, result, weak=False):
  123. concrete, weak_ = self._pending_results
  124. if task_id not in weak_ and result.id not in concrete:
  125. (weak_ if weak else concrete)[task_id] = result
  126. self.result_consumer.consume_from(task_id)
  127. def add_pending_results(self, results, weak=False):
  128. self.result_consumer.drainer.start()
  129. return [self.add_pending_result(result, weak=weak, start_drainer=False)
  130. for result in results]
  131. def remove_pending_result(self, result):
  132. self._remove_pending_result(result.id)
  133. self.on_result_fulfilled(result)
  134. return result
  135. def _remove_pending_result(self, task_id):
  136. for map in self._pending_results:
  137. map.pop(task_id, None)
  138. def on_result_fulfilled(self, result):
  139. self.result_consumer.cancel_for(result.id)
  140. def wait_for_pending(self, result,
  141. callback=None, propagate=True, **kwargs):
  142. self._ensure_not_eager()
  143. for _ in self._wait_for_pending(result, **kwargs):
  144. pass
  145. return result.maybe_throw(callback=callback, propagate=propagate)
  146. def _wait_for_pending(self, result,
  147. timeout=None, on_interval=None, on_message=None,
  148. **kwargs):
  149. return self.result_consumer._wait_for_pending(
  150. result, timeout=timeout,
  151. on_interval=on_interval, on_message=on_message,
  152. )
  153. @property
  154. def is_async(self):
  155. return True
class BaseResultConsumer:
    """Manager responsible for consuming result messages."""

    def __init__(self, backend, app, accept,
                 pending_results, pending_messages):
        self.backend = backend
        self.app = app
        # Accepted content types for result messages
        # (presumably serializer names — confirm with caller).
        self.accept = accept
        # (concrete, weak) mappings of task_id -> result, shared with
        # the backend mixin's _pending_results.
        self._pending_results = pending_results
        # Buffer for messages that arrive before their result is
        # registered; provides put()/take() (see on_state_change).
        self._pending_messages = pending_messages
        # Optional per-message callback, installed temporarily by
        # _wait_for_pending.
        self.on_message = None
        # result -> deque buckets used by iter_native; weak keys so
        # fulfilled results do not keep entries alive.
        self.buckets = WeakKeyDictionary()
        # Pick the drainer implementation matching the current
        # execution environment (default/eventlet/gevent).
        self.drainer = drainers[detect_environment()](self)

    def start(self, initial_task_id, **kwargs):
        raise NotImplementedError()

    def stop(self):
        pass

    def drain_events(self, timeout=None):
        raise NotImplementedError()

    def consume_from(self, task_id):
        raise NotImplementedError()

    def cancel_for(self, task_id):
        raise NotImplementedError()

    def _after_fork(self):
        # Reset state that must not be shared with the parent process.
        self.buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

    def on_after_fork(self):
        # Hook for subclasses; called at the end of _after_fork.
        pass

    def drain_events_until(self, p, timeout=None, on_interval=None):
        # Delegate to the environment-specific drainer.
        return self.drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        """Generator draining events until ``result`` is ready.

        Raises:
            TimeoutError: (celery.exceptions) if ``timeout`` elapses.
        """
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        # Temporarily install the caller's on_message callback; the
        # previous one is restored in the finally block below.
        prev_on_m, self.on_message = self.on_message, on_message
        try:
            for _ in self.drain_events_until(
                    result.on_ready, timeout=timeout,
                    on_interval=on_interval):
                yield
                sleep(0)
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = prev_on_m

    def on_wait_for_pending(self, result, timeout=None, **kwargs):
        # Hook for subclasses; called before draining starts.
        pass

    def on_out_of_band_result(self, message):
        self.on_state_change(message.payload, message)

    def _get_pending_result(self, task_id):
        # Search both the concrete and weak result mappings.
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
            except KeyError:
                pass
        raise KeyError(task_id)

    def on_state_change(self, meta, message):
        """Handle a state-change message for any task.

        Ready states fulfil the matching pending result, or buffer the
        meta when the result has not been registered yet.
        """
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            task_id = meta['task_id']
            try:
                result = self._get_pending_result(task_id)
            except KeyError:
                # send to buffer in case we received this result
                # before it was added to _pending_results.
                self._pending_messages.put(task_id, meta)
            else:
                result._maybe_set_cache(meta)
                buckets = self.buckets
                try:
                    # remove bucket for this result, since it's fulfilled
                    bucket = buckets.pop(result)
                except KeyError:
                    pass
                else:
                    # send to waiter via bucket
                    bucket.append(result)
        # Yield control so waiters can observe the new state promptly.
        sleep(0)