# rpc.py
  1. # -*- coding: utf-8 -*-
  2. """The ``RPC`` result backend for AMQP brokers.
  3. RPC-style result backend, using reply-to and one queue per client.
  4. """
  5. from __future__ import absolute_import, unicode_literals
  6. import time
  7. from kombu import Consumer, Exchange, Producer, Queue
  8. from kombu.common import maybe_declare
  9. from kombu.utils.compat import register_after_fork
  10. from kombu.utils.objects import cached_property
  11. from celery import current_task
  12. from celery import states
  13. from celery._state import task_join_will_block
  14. from celery.five import items, range
  15. from celery.utils.functional import dictfilter
  16. from celery.utils.time import maybe_s_to_ms
  17. from . import base
  18. from .async import AsyncBackendMixin, BaseResultConsumer
  19. __all__ = ['BacklogLimitExceeded', 'BaseRPCBackend', 'RPCBackend']
  20. class BacklogLimitExceeded(Exception):
  21. """Too much state history to fast-forward."""
  22. class NoCacheQueue(Queue):
  23. can_cache_declaration = False
  24. def _on_after_fork_cleanup_backend(backend):
  25. backend._after_fork()
  26. class ResultConsumer(BaseResultConsumer):
  27. Consumer = Consumer
  28. _connection = None
  29. _consumer = None
  30. def __init__(self, *args, **kwargs):
  31. super(ResultConsumer, self).__init__(*args, **kwargs)
  32. self._create_binding = self.backend._create_binding
  33. def start(self, initial_task_id, no_ack=True):
  34. self._connection = self.app.connection()
  35. initial_queue = self._create_binding(initial_task_id)
  36. self._consumer = self.Consumer(
  37. self._connection.default_channel, [initial_queue],
  38. callbacks=[self.on_state_change], no_ack=no_ack,
  39. accept=self.accept)
  40. self._consumer.consume()
  41. def drain_events(self, timeout=None):
  42. if self._connection:
  43. return self._connection.drain_events(timeout=timeout)
  44. elif timeout:
  45. time.sleep(timeout)
  46. def stop(self):
  47. try:
  48. self._consumer.cancel()
  49. finally:
  50. self._connection.close()
  51. def on_after_fork(self):
  52. self._consumer = None
  53. if self._connection is not None:
  54. self._connection.collect()
  55. self._connection = None
  56. def consume_from(self, task_id):
  57. if self._consumer is None:
  58. return self.start(task_id)
  59. queue = self._create_binding(task_id)
  60. if not self._consumer.consuming_from(queue):
  61. self._consumer.add_queue(queue)
  62. self._consumer.consume()
  63. def cancel_for(self, task_id):
  64. if self._consumer:
  65. self._consumer.cancel_by_queue(self._create_binding(task_id).name)
  66. class BaseRPCBackend(base.Backend, AsyncBackendMixin):
  67. """Base class for the RPC result backend."""
  68. Exchange = Exchange
  69. Queue = NoCacheQueue
  70. Consumer = Consumer
  71. Producer = Producer
  72. ResultConsumer = ResultConsumer
  73. BacklogLimitExceeded = BacklogLimitExceeded
  74. persistent = True
  75. supports_autoexpire = True
  76. supports_native_join = True
  77. retry_policy = {
  78. 'max_retries': 20,
  79. 'interval_start': 0,
  80. 'interval_step': 1,
  81. 'interval_max': 1,
  82. }
  83. def __init__(self, app, connection=None, exchange=None, exchange_type=None,
  84. persistent=None, serializer=None, auto_delete=True, **kwargs):
  85. super(BaseRPCBackend, self).__init__(app, **kwargs)
  86. conf = self.app.conf
  87. self._connection = connection
  88. self._out_of_band = {}
  89. self.persistent = self.prepare_persistent(persistent)
  90. self.delivery_mode = 2 if self.persistent else 1
  91. exchange = exchange or conf.result_exchange
  92. exchange_type = exchange_type or conf.result_exchange_type
  93. self.exchange = self._create_exchange(
  94. exchange, exchange_type, self.delivery_mode,
  95. )
  96. self.serializer = serializer or conf.result_serializer
  97. self.auto_delete = auto_delete
  98. self.queue_arguments = dictfilter({
  99. 'x-expires': maybe_s_to_ms(self.expires),
  100. })
  101. self.result_consumer = self.ResultConsumer(
  102. self, self.app, self.accept,
  103. self._pending_results, self._pending_messages,
  104. )
  105. if register_after_fork is not None:
  106. register_after_fork(self, _on_after_fork_cleanup_backend)
  107. def _after_fork(self):
  108. self._pending_results.clear()
  109. self.result_consumer._after_fork()
  110. def store_result(self, task_id, result, state,
  111. traceback=None, request=None, **kwargs):
  112. """Send task return value and state."""
  113. routing_key, correlation_id = self.destination_for(task_id, request)
  114. if not routing_key:
  115. return
  116. with self.app.amqp.producer_pool.acquire(block=True) as producer:
  117. producer.publish(
  118. {'task_id': task_id, 'status': state,
  119. 'result': self.encode_result(result, state),
  120. 'traceback': traceback,
  121. 'children': self.current_task_children(request)},
  122. exchange=self.exchange,
  123. routing_key=routing_key,
  124. correlation_id=correlation_id,
  125. serializer=self.serializer,
  126. retry=True, retry_policy=self.retry_policy,
  127. declare=self.on_reply_declare(task_id),
  128. delivery_mode=self.delivery_mode,
  129. )
  130. return result
  131. def on_out_of_band_result(self, task_id, message):
  132. if self.result_consumer:
  133. self.result_consumer.on_out_of_band_result(message)
  134. self._out_of_band[task_id] = message
  135. def get_task_meta(self, task_id, backlog_limit=1000):
  136. buffered = self._out_of_band.pop(task_id, None)
  137. if buffered:
  138. return self._set_cache_by_message(task_id, buffered)
  139. # Polling and using basic_get
  140. latest_by_id = {}
  141. prev = None
  142. for acc in self._slurp_from_queue(task_id, self.accept, backlog_limit):
  143. tid = self._get_message_task_id(acc)
  144. prev, latest_by_id[tid] = latest_by_id.get(tid), acc
  145. if prev:
  146. # backends aren't expected to keep history,
  147. # so we delete everything except the most recent state.
  148. prev.ack()
  149. prev = None
  150. latest = latest_by_id.pop(task_id, None)
  151. for tid, msg in items(latest_by_id):
  152. self.on_out_of_band_result(tid, msg)
  153. if latest:
  154. latest.requeue()
  155. return self._set_cache_by_message(task_id, latest)
  156. else:
  157. # no new state, use previous
  158. try:
  159. return self._cache[task_id]
  160. except KeyError:
  161. # result probably pending.
  162. return {'status': states.PENDING, 'result': None}
  163. poll = get_task_meta # XXX compat
  164. def _set_cache_by_message(self, task_id, message):
  165. payload = self._cache[task_id] = self.meta_from_decoded(
  166. message.payload)
  167. return payload
  168. def _slurp_from_queue(self, task_id, accept,
  169. limit=1000, no_ack=False):
  170. with self.app.pool.acquire_channel(block=True) as (_, channel):
  171. binding = self._create_binding(task_id)(channel)
  172. binding.declare()
  173. for i in range(limit):
  174. msg = binding.get(accept=accept, no_ack=no_ack)
  175. if not msg:
  176. break
  177. yield msg
  178. else:
  179. raise self.BacklogLimitExceeded(task_id)
  180. def _get_message_task_id(self, message):
  181. try:
  182. # try property first so we don't have to deserialize
  183. # the payload.
  184. return message.properties['correlation_id']
  185. except (AttributeError, KeyError):
  186. # message sent by old Celery version, need to deserialize.
  187. return message.payload['task_id']
  188. def revive(self, channel):
  189. pass
  190. def reload_task_result(self, task_id):
  191. raise NotImplementedError(
  192. 'reload_task_result is not supported by this backend.')
  193. def reload_group_result(self, task_id):
  194. """Reload group result, even if it has been previously fetched."""
  195. raise NotImplementedError(
  196. 'reload_group_result is not supported by this backend.')
  197. def save_group(self, group_id, result):
  198. raise NotImplementedError(
  199. 'save_group is not supported by this backend.')
  200. def restore_group(self, group_id, cache=True):
  201. raise NotImplementedError(
  202. 'restore_group is not supported by this backend.')
  203. def delete_group(self, group_id):
  204. raise NotImplementedError(
  205. 'delete_group is not supported by this backend.')
  206. def __reduce__(self, args=(), kwargs={}):
  207. return super(BaseRPCBackend, self).__reduce__(args, dict(
  208. kwargs,
  209. connection=self._connection,
  210. exchange=self.exchange.name,
  211. exchange_type=self.exchange.type,
  212. persistent=self.persistent,
  213. serializer=self.serializer,
  214. auto_delete=self.auto_delete,
  215. expires=self.expires,
  216. ))
  217. class RPCBackend(BaseRPCBackend):
  218. """RPC result backend."""
  219. persistent = False
  220. class Consumer(Consumer):
  221. auto_declare = False
  222. def _create_exchange(self, name, type='direct', delivery_mode=2):
  223. # uses direct to queue routing (anon exchange).
  224. return Exchange(None)
  225. def _create_binding(self, task_id):
  226. return self.binding
  227. def on_task_call(self, producer, task_id):
  228. if not task_join_will_block():
  229. maybe_declare(self.binding(producer.channel), retry=True)
  230. def rkey(self, task_id):
  231. return task_id
  232. def destination_for(self, task_id, request):
  233. # Request is a new argument for backends, so must still support
  234. # old code that rely on current_task
  235. try:
  236. request = request or current_task.request
  237. except AttributeError:
  238. raise RuntimeError(
  239. 'RPC backend missing task request for {0!r}'.format(task_id))
  240. return request.reply_to, request.correlation_id or task_id
  241. def on_reply_declare(self, task_id):
  242. pass
  243. def on_result_fulfilled(self, result):
  244. pass
  245. def as_uri(self, include_password=True):
  246. return 'rpc://'
  247. @property
  248. def binding(self):
  249. return self.Queue(
  250. self.oid, self.exchange, self.oid,
  251. durable=False, auto_delete=True
  252. )
  253. @cached_property
  254. def oid(self):
  255. return self.app.oid