# -*- coding: utf-8 -*-
"""
    celery.backends.amqp
    ~~~~~~~~~~~~~~~~~~~~

    The AMQP result backend.

    This backend publishes results as messages.

"""
from __future__ import absolute_import

import socket
import threading
import time

from collections import deque

from kombu import Exchange, Queue, Producer, Consumer

from celery import states
from celery.exceptions import TimeoutError
from celery.five import range
from celery.utils.log import get_logger

from .base import BaseBackend

logger = get_logger(__name__)
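
# A minimal usage sketch (illustrative only, not part of this module): with
# an app configured to use this backend, results travel as AMQP messages and
# come back through the normal result API.  The broker URL and task below
# are hypothetical.
#
#     from celery import Celery
#
#     app = Celery('tasks', broker='amqp://', backend='amqp')
#
#     @app.task
#     def add(x, y):
#         return x + y
#
#     result = add.delay(2, 2)
#     print(result.get(timeout=5))    # blocks via AMQPBackend.wait_for()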


class BacklogLimitExceeded(Exception):
    """Too much state history to fast-forward."""


def repair_uuid(s):
    # Historically the dashes in UUIDs are removed from AMQ entity names,
    # but there is no known reason to do so.  Hopefully we'll be able to fix
    # this in v4.0.
    return '%s-%s-%s-%s-%s' % (s[:8], s[8:12], s[12:16], s[16:20], s[20:])
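# For example, repair_uuid('5e6670d9d1a846f2a6a0d2c3f7d0c2a1') returns
# '5e6670d9-d1a8-46f2-a6a0-d2c3f7d0c2a1' (the value here is made up).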


class AMQPBackend(BaseBackend):
    """Publishes results by sending messages."""
    Exchange = Exchange
    Queue = Queue
    Consumer = Consumer
    Producer = Producer

    BacklogLimitExceeded = BacklogLimitExceeded

    supports_native_join = True

    retry_policy = {
        'max_retries': 20,
        'interval_start': 0,
        'interval_step': 1,
        'interval_max': 1,
    }
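    # The policy above is handed to kombu's ``Producer.publish(retry=True,
    # retry_policy=...)``: publishing is retried up to 20 times, the first
    # retry happens immediately, and the delay grows by one second per
    # attempt, capped at one second between retries.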

    def __init__(self, connection=None, exchange=None, exchange_type=None,
                 persistent=None, serializer=None, auto_delete=True,
                 **kwargs):
        super(AMQPBackend, self).__init__(**kwargs)
        conf = self.app.conf
        self._connection = connection
        self.queue_arguments = {}
        self.persistent = (conf.CELERY_RESULT_PERSISTENT if persistent is None
                           else persistent)
        exchange = exchange or conf.CELERY_RESULT_EXCHANGE
        exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
        self.exchange = self._create_exchange(exchange, exchange_type,
                                              self.persistent)
        self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
        self.auto_delete = auto_delete

        self.expires = None
        if 'expires' not in kwargs or kwargs['expires'] is not None:
            self.expires = self.prepare_expires(kwargs.get('expires'))
        if self.expires:
            # RabbitMQ expects the x-expires queue argument in milliseconds.
            self.queue_arguments['x-expires'] = int(self.expires * 1000)
        self.mutex = threading.Lock()

    def _create_exchange(self, name, type='direct', persistent=True):
        delivery_mode = persistent and 'persistent' or 'transient'
        return self.Exchange(name=name,
                             type=type,
                             delivery_mode=delivery_mode,
                             durable=self.persistent,
                             auto_delete=False)

    def _create_binding(self, task_id):
        name = task_id.replace('-', '')
        return self.Queue(name=name,
                          exchange=self.exchange,
                          routing_key=name,
                          durable=self.persistent,
                          auto_delete=self.auto_delete,
                          queue_arguments=self.queue_arguments)
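    # Each task thus gets its own reply queue: the queue name and routing key
    # are both the task id with the dashes stripped, bound to the result
    # exchange, and the queue carries any ``x-expires`` argument set above.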

    def revive(self, channel):
        pass

    def _routing_key(self, task_id):
        return task_id.replace('-', '')

    def _republish(self, channel, task_id, body, content_type,
                   content_encoding):
        return Producer(channel).publish(
            body,
            exchange=self.exchange,
            routing_key=self._routing_key(task_id),
            serializer=self.serializer,
            content_type=content_type,
            content_encoding=content_encoding,
            retry=True, retry_policy=self.retry_policy,
            declare=self.on_reply_declare(task_id),
        )

    def _store_result(self, task_id, result, status, traceback=None):
        """Send task return value and status."""
        with self.mutex:
            with self.app.amqp.producer_pool.acquire(block=True) as pub:
                pub.publish({'task_id': task_id, 'status': status,
                             'result': self.encode_result(result, status),
                             'traceback': traceback,
                             'children': self.current_task_children()},
                            exchange=self.exchange,
                            routing_key=self._routing_key(task_id),
                            serializer=self.serializer,
                            retry=True, retry_policy=self.retry_policy,
                            declare=self.on_reply_declare(task_id))
        return result
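    # For illustration, a successful result for the hypothetical ``add`` task
    # shown at the top of this module would be published with a body roughly
    # like:
    #
    #     {'task_id': '5e6670d9-...', 'status': 'SUCCESS', 'result': 4,
    #      'traceback': None, 'children': []}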

    def on_reply_declare(self, task_id):
        return [self._create_binding(task_id)]

    def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
                 **kwargs):
        cached_meta = self._cache.get(task_id)
        if cache and cached_meta and \
                cached_meta['status'] in states.READY_STATES:
            meta = cached_meta
        else:
            try:
                meta = self.consume(task_id, timeout=timeout)
            except socket.timeout:
                raise TimeoutError('The operation timed out.')

        state = meta['status']
        if state == states.SUCCESS:
            return meta['result']
        elif state in states.PROPAGATE_STATES:
            if propagate:
                raise self.exception_to_python(meta['result'])
            return meta['result']
        else:
            return self.wait_for(task_id, timeout, cache)

    def get_task_meta(self, task_id, backlog_limit=1000):
        # Polling and using basic_get
        with self.app.pool.acquire_channel(block=True) as (_, channel):
            binding = self._create_binding(task_id)(channel)
            binding.declare()

            prev = latest = acc = None
            for i in range(backlog_limit):  # spool ffwd
                prev, latest, acc = latest, acc, binding.get(no_ack=False)
                if not acc:  # no more messages
                    break
                if prev:
                    # backends are not expected to keep history,
                    # so we delete everything except the most recent state.
                    prev.ack()
            else:
                raise self.BacklogLimitExceeded(task_id)

            if latest:
                payload = self._cache[task_id] = latest.payload
                latest.requeue()
                return payload
            else:
                # no new state, use previous
                try:
                    return self._cache[task_id]
                except KeyError:
                    # result probably pending.
                    return {'status': states.PENDING, 'result': None}
    poll = get_task_meta  # XXX compat
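    # get_task_meta() above is the polling path used when the result state is
    # inspected rather than waited for: it drains any backlog of state
    # messages for the task, acks all but the newest one, caches that newest
    # state and requeues it so a later poll (or consumer) can still see it.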

    def drain_events(self, connection, consumer,
                     timeout=None, now=time.time, wait=None):
        wait = wait or connection.drain_events
        results = {}

        def callback(meta, message):
            if meta['status'] in states.READY_STATES:
                results[meta['task_id']] = meta
        consumer.callbacks[:] = [callback]
        time_start = now()

        while 1:
            # Total time spent may exceed a single call to wait()
            if timeout and now() - time_start >= timeout:
                raise socket.timeout()
            wait(timeout=timeout)
            if results:  # got event on the wanted channel.
                break
        self._cache.update(results)
        return results

    def consume(self, task_id, timeout=None):
        wait = self.drain_events
        with self.app.pool.acquire_channel(block=True) as (conn, channel):
            binding = self._create_binding(task_id)
            with self.Consumer(channel, binding, no_ack=True) as consumer:
                while 1:
                    try:
                        return wait(conn, consumer, timeout)[task_id]
                    except KeyError:
                        continue
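    # consume() blocks on a dedicated consumer for a single task's reply
    # queue, looping until drain_events() reports a ready state for that
    # task id (any other states it sees are cached in the meantime).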

    def _many_bindings(self, ids):
        return [self._create_binding(task_id) for task_id in ids]

    def get_many(self, task_ids, timeout=None, now=time.time, **kwargs):
        with self.app.pool.acquire_channel(block=True) as (conn, channel):
            ids = set(task_ids)
            cached_ids = set()
            for task_id in ids:
                try:
                    cached = self._cache[task_id]
                except KeyError:
                    pass
                else:
                    if cached['status'] in states.READY_STATES:
                        yield task_id, cached
                        cached_ids.add(task_id)
            ids.difference_update(cached_ids)
            results = deque()

            def callback(meta, message):
                if meta['status'] in states.READY_STATES:
                    task_id = meta['task_id']
                    if task_id in task_ids:
                        results.append(meta)
                    else:
                        self._cache[task_id] = meta

            bindings = self._many_bindings(task_ids)
            with self.Consumer(channel, bindings,
                               callbacks=[callback], no_ack=True):
                wait = conn.drain_events
                popleft = results.popleft
                while ids:
                    wait(timeout=timeout)

                    while results:
                        meta = popleft()
                        task_id = meta['task_id']
                        ids.discard(task_id)
                        self._cache[task_id] = meta
                        yield task_id, meta
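    # get_many() is a generator and is what makes ``supports_native_join``
    # possible: results are yielded as their messages arrive instead of being
    # polled one task at a time.  A minimal sketch of calling it directly
    # (the task ids are hypothetical; normally a result set's native join
    # does this for you):
    #
    #     for task_id, meta in app.backend.get_many({'id-1', 'id-2'},
    #                                               timeout=10):
    #         print(task_id, meta['status'], meta['result'])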

    def reload_task_result(self, task_id):
        raise NotImplementedError(
            'reload_task_result is not supported by this backend.')

    def reload_group_result(self, task_id):
        """Reload group result, even if it has been previously fetched."""
        raise NotImplementedError(
            'reload_group_result is not supported by this backend.')

    def save_group(self, group_id, result):
        raise NotImplementedError(
            'save_group is not supported by this backend.')

    def restore_group(self, group_id, cache=True):
        raise NotImplementedError(
            'restore_group is not supported by this backend.')

    def delete_group(self, group_id):
        raise NotImplementedError(
            'delete_group is not supported by this backend.')

    def __reduce__(self, args=(), kwargs={}):
        kwargs.update(
            connection=self._connection,
            exchange=self.exchange.name,
            exchange_type=self.exchange.type,
            persistent=self.persistent,
            serializer=self.serializer,
            auto_delete=self.auto_delete,
            expires=self.expires,
        )
        return super(AMQPBackend, self).__reduce__(args, kwargs)