# -*- coding: utf-8 -*-
"""
    celery.backends.amqp
    ~~~~~~~~~~~~~~~~~~~~

    The AMQP result backend.

    This backend publishes results as messages.

"""
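
# A minimal usage sketch (illustrative only; assumes a reachable AMQP
# broker on the default URL, configured as both broker and result
# backend):
#
#     from celery import Celery
#     app = Celery('proj', broker='amqp://', backend='amqp://')
#
#     @app.task
#     def add(x, y):
#         return x + y
#
#     # Each result is published to a per-task reply queue, so with this
#     # backend a result can only be retrieved once:
#     add.delay(2, 2).get(timeout=10)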

from __future__ import absolute_import

import socket

from collections import deque
from operator import itemgetter

from kombu import Exchange, Queue, Producer, Consumer

from celery import states
from celery.exceptions import TimeoutError
from celery.five import range, monotonic
from celery.utils.functional import dictfilter
from celery.utils.log import get_logger
from celery.utils.timeutils import maybe_s_to_ms

from .base import BaseBackend

__all__ = ['BacklogLimitExceeded', 'AMQPBackend']

logger = get_logger(__name__)


class BacklogLimitExceeded(Exception):
    """Too much state history to fast-forward."""


def repair_uuid(s):
    # Historically the dashes in UUIDs are removed from AMQ entity names,
    # but there is no known reason to do so.  Hopefully we'll be able to
    # fix this in v4.0.
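    # For example (hypothetical id; pure string slicing):
    #   repair_uuid('a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6')
    #   -> 'a1b2c3d4-e5f6-a7b8-c9d0-e1f2a3b4c5d6'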
    return '%s-%s-%s-%s-%s' % (s[:8], s[8:12], s[12:16], s[16:20], s[20:])


class NoCacheQueue(Queue):
    can_cache_declaration = False


class ResultConsumer(object):
    Consumer = Consumer

    def __init__(self, backend, app, accept, pending_results):
        self.backend = backend
        self.app = app
        self.accept = accept
        self._pending_results = pending_results
        self._consumer = None
        self._conn = None
        self.on_message = None
        self.bucket = None

    def drain_events(self, connection, consumer,
                     timeout=None, on_interval=None, now=monotonic, wait=None):
        # Gather ready state messages arriving on ``consumer`` into a
        # mapping of {task_id: meta}, waking up once a second so that
        # ``timeout`` and ``on_interval`` can be honoured between waits.
        wait = wait or connection.drain_events
        results = {}

        def callback(meta, message):
            if meta['status'] in states.READY_STATES:
                results[meta['task_id']] = self.backend.meta_from_decoded(meta)

        consumer.callbacks[:] = [callback]
        time_start = now()

        while 1:
            # Total time spent may exceed a single call to wait()
            if timeout and now() - time_start >= timeout:
                raise socket.timeout()
            try:
                wait(timeout=1)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if results:  # got event on the wanted channel.
                break
        return results

    def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
        wait = self.drain_events
        with self.app.pool.acquire_channel(block=True) as (conn, channel):
            binding = self.backend._create_binding(task_id)
            with self.Consumer(channel, binding,
                               no_ack=no_ack, accept=self.accept) as consumer:
                while 1:
                    try:
                        return wait(
                            conn, consumer, timeout, on_interval)[task_id]
                    except KeyError:
                        continue

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result, timeout=None, interval=0.5,
                          no_ack=True, on_interval=None, callback=None,
                          on_message=None, propagate=True):
        prev_on_m, self.on_message = self.on_message, on_message
        try:
            for _ in self.drain_events_until(
                    result.on_ready, timeout=timeout,
                    on_interval=on_interval):
                yield
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = prev_on_m

    def collect_for_pending(self, result, bucket=None, **kwargs):
        prev_bucket, self.bucket = self.bucket, bucket
        try:
            for _ in self._wait_for_pending(result, **kwargs):
                yield
        finally:
            self.bucket = prev_bucket

    def start(self, initial_queue, no_ack=True):
        self._conn = self.app.connection()
        self._consumer = self.Consumer(
            self._conn.default_channel, [initial_queue],
            callbacks=[self.on_state_change], no_ack=no_ack,
            accept=self.accept)
        self._consumer.consume()

    def stop(self):
        try:
            self._consumer.cancel()
        finally:
            self._conn.close()

    def consume_from(self, queue):
        if self._consumer is None:
            return self.start(queue)
        if not self._consumer.consuming_from(queue):
            self._consumer.add_queue(queue)
            self._consumer.consume()

    def cancel_for(self, queue):
        self._consumer.cancel_by_queue(queue)

    def on_state_change(self, meta, message):
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            try:
                result = self._pending_results[meta['task_id']]
            except KeyError:
                return
            result._maybe_set_cache(meta)
            if self.bucket is not None:
                self.bucket.append(result)
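
    # drain_events_until polls the dedicated result connection in short
    # one-second waits so that ``timeout`` and ``on_interval`` can be
    # honoured between waits, and stops as soon as the promise ``p`` has
    # been fulfilled.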
    def drain_events_until(self, p, timeout=None, on_interval=None,
                           monotonic=monotonic, wait=None):
        wait = wait or self._conn.drain_events
        time_start = monotonic()

        while 1:
            # Total time spent may exceed a single call to wait()
            if timeout and monotonic() - time_start >= timeout:
                raise socket.timeout()
            try:
                yield wait(timeout=1)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break


class AMQPBackend(BaseBackend):
    """Publishes results by sending messages."""

    Exchange = Exchange
    Queue = NoCacheQueue
    Consumer = Consumer
    Producer = Producer
    ResultConsumer = ResultConsumer

    BacklogLimitExceeded = BacklogLimitExceeded

    persistent = True
    supports_autoexpire = True
    supports_native_join = True

    retry_policy = {
        'max_retries': 20,
        'interval_start': 0,
        'interval_step': 1,
        'interval_max': 1,
    }
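
    # i.e. publishing a result is retried for at most 20 attempts: the
    # first retry happens immediately and the delay then grows in
    # 1-second steps, capped at 1 second between retries (these are
    # kombu's standard retry_policy options).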

    def __init__(self, app, connection=None, exchange=None, exchange_type=None,
                 persistent=None, serializer=None, auto_delete=True, **kwargs):
        super(AMQPBackend, self).__init__(app, **kwargs)
        conf = self.app.conf
        self._connection = connection
        self.persistent = self.prepare_persistent(persistent)
        self.delivery_mode = 2 if self.persistent else 1
        exchange = exchange or conf.result_exchange
        exchange_type = exchange_type or conf.result_exchange_type
        self.exchange = self._create_exchange(
            exchange, exchange_type, self.delivery_mode,
        )
        self.serializer = serializer or conf.result_serializer
        self.auto_delete = auto_delete
        self.queue_arguments = dictfilter({
            'x-expires': maybe_s_to_ms(self.expires),
        })
        self.result_consumer = self.ResultConsumer(
            self, self.app, self.accept, self._pending_results)

    def _create_exchange(self, name, type='direct', delivery_mode=2):
        return self.Exchange(name=name,
                             type=type,
                             delivery_mode=delivery_mode,
                             durable=self.persistent,
                             auto_delete=False)

    def _create_binding(self, task_id):
        name = self.rkey(task_id)
        return self.Queue(name=name,
                          exchange=self.exchange,
                          routing_key=name,
                          durable=self.persistent,
                          auto_delete=self.auto_delete,
                          queue_arguments=self.queue_arguments)

    def revive(self, channel):
        pass

    def rkey(self, task_id):
        return task_id.replace('-', '')

    def destination_for(self, task_id, request):
        if request:
            return self.rkey(task_id), request.correlation_id or task_id
        return self.rkey(task_id), task_id

    def store_result(self, task_id, result, state,
                     traceback=None, request=None, **kwargs):
        """Send task return value and state."""
        routing_key, correlation_id = self.destination_for(task_id, request)
        if not routing_key:
            return
        with self.app.amqp.producer_pool.acquire(block=True) as producer:
            producer.publish(
                {'task_id': task_id, 'status': state,
                 'result': self.encode_result(result, state),
                 'traceback': traceback,
                 'children': self.current_task_children(request)},
                exchange=self.exchange,
                routing_key=routing_key,
                correlation_id=correlation_id,
                serializer=self.serializer,
                retry=True, retry_policy=self.retry_policy,
                declare=self.on_reply_declare(task_id),
                delivery_mode=self.delivery_mode,
            )
        return result
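
    # The published meta is a plain mapping; for a successful task it
    # looks roughly like this (illustrative values only):
    #
    #     {'task_id': '...', 'status': 'SUCCESS', 'result': 4,
    #      'traceback': None, 'children': []}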

    def on_reply_declare(self, task_id):
        return [self._create_binding(task_id)]
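
    # get_task_meta polls using basic_get: it spools through the reply
    # queue, acking every state message except the most recent one, which
    # is requeued so that later polls (and other consumers) still see it.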
    def get_task_meta(self, task_id, backlog_limit=1000):
        # Polling and using basic_get
        with self.app.pool.acquire_channel(block=True) as (_, channel):
            binding = self._create_binding(task_id)(channel)
            binding.declare()

            prev = latest = acc = None
            for i in range(backlog_limit):  # spool ffwd
                acc = binding.get(
                    accept=self.accept, no_ack=False,
                )
                if not acc:  # no more messages
                    break
                if acc.payload['task_id'] == task_id:
                    prev, latest = latest, acc
                if prev:
                    # backends are not expected to keep history,
                    # so we delete everything except the most recent state.
                    prev.ack()
                    prev = None
            else:
                raise self.BacklogLimitExceeded(task_id)

            if latest:
                payload = self._cache[task_id] = self.meta_from_decoded(
                    latest.payload)
                latest.requeue()
                return payload
            else:
                # no new state, use previous
                try:
                    return self._cache[task_id]
                except KeyError:
                    # result probably pending.
                    return {'status': states.PENDING, 'result': None}
    poll = get_task_meta  # XXX compat

    def wait_for_pending(self, result, timeout=None, interval=0.5,
                         no_ack=True, on_interval=None, on_message=None,
                         callback=None, propagate=True):
        return self.result_consumer.wait_for_pending(
            result, timeout=timeout, interval=interval,
            no_ack=no_ack, on_interval=on_interval,
            callback=callback, on_message=on_message, propagate=propagate,
        )

    def collect_for_pending(self, result, bucket=None, timeout=None,
                            interval=0.5, no_ack=True, on_interval=None,
                            on_message=None, callback=None, propagate=True):
        return self.result_consumer.collect_for_pending(
            result, bucket=bucket, timeout=timeout, interval=interval,
            no_ack=no_ack, on_interval=on_interval,
            callback=callback, on_message=on_message, propagate=propagate,
        )

    def add_pending_result(self, result):
        if result.id not in self._pending_results:
            self._pending_results[result.id] = result
            self.result_consumer.consume_from(self._create_binding(result.id))

    def remove_pending_result(self, result):
        self._pending_results.pop(result.id, None)
        # XXX cancel queue after result consumed

    def _many_bindings(self, ids):
        return [self._create_binding(task_id) for task_id in ids]
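
    # xxx_get_many works in two phases: ready states already in the local
    # cache are yielded up front, then a single Consumer bound to all the
    # remaining per-task queues drains events until every requested id
    # has been accounted for.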
    def xxx_get_many(self, task_ids, timeout=None, no_ack=True,
                     on_message=None, on_interval=None,
                     now=monotonic, getfields=itemgetter('status', 'task_id'),
                     READY_STATES=states.READY_STATES,
                     PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
        with self.app.pool.acquire_channel(block=True) as (conn, channel):
            ids = set(task_ids)
            cached_ids = set()
            mark_cached = cached_ids.add
            for task_id in ids:
                try:
                    cached = self._cache[task_id]
                except KeyError:
                    pass
                else:
                    if cached['status'] in READY_STATES:
                        yield task_id, cached
                        mark_cached(task_id)
            ids.difference_update(cached_ids)

            results = deque()
            push_result = results.append
            push_cache = self._cache.__setitem__
            decode_result = self.meta_from_decoded

            def _on_message(message):
                body = decode_result(message.decode())
                if on_message is not None:
                    on_message(body)
                state, uid = getfields(body)
                if state in READY_STATES:
                    push_result(body) \
                        if uid in task_ids else push_cache(uid, body)

            bindings = self._many_bindings(task_ids)
            with self.Consumer(channel, bindings, on_message=_on_message,
                               accept=self.accept, no_ack=no_ack):
                wait = conn.drain_events
                popleft = results.popleft
                while ids:
                    wait(timeout=timeout)
                    while results:
                        state = popleft()
                        task_id = state['task_id']
                        ids.discard(task_id)
                        push_cache(task_id, state)
                        yield task_id, state
                    if on_interval:
                        on_interval()

    def reload_task_result(self, task_id):
        raise NotImplementedError(
            'reload_task_result is not supported by this backend.')

    def reload_group_result(self, task_id):
        """Reload group result, even if it has been previously fetched."""
        raise NotImplementedError(
            'reload_group_result is not supported by this backend.')

    def save_group(self, group_id, result):
        raise NotImplementedError(
            'save_group is not supported by this backend.')

    def restore_group(self, group_id, cache=True):
        raise NotImplementedError(
            'restore_group is not supported by this backend.')

    def delete_group(self, group_id):
        raise NotImplementedError(
            'delete_group is not supported by this backend.')

    def __reduce__(self, args=(), kwargs=None):
        # avoid mutating a shared default argument.
        kwargs = {} if kwargs is None else kwargs
        kwargs.update(
            connection=self._connection,
            exchange=self.exchange.name,
            exchange_type=self.exchange.type,
            persistent=self.persistent,
            serializer=self.serializer,
            auto_delete=self.auto_delete,
            expires=self.expires,
        )
        return super(AMQPBackend, self).__reduce__(args, kwargs)