test_amqp.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. from __future__ import absolute_import, unicode_literals
  2. import pickle
  3. import pytest
  4. from contextlib import contextmanager
  5. from datetime import timedelta
  6. from pickle import dumps, loads
  7. from case import Mock, mock
  8. from billiard.einfo import ExceptionInfo
  9. from celery import states
  10. from celery import uuid
  11. from celery.backends.amqp import AMQPBackend
  12. from celery.five import Empty, Queue, range
  13. from celery.result import AsyncResult
  14. class SomeClass(object):
  15. def __init__(self, data):
  16. self.data = data
  17. class test_AMQPBackend:
  18. def setup(self):
  19. self.app.conf.result_cache_max = 100
  20. def create_backend(self, **opts):
  21. opts = dict(dict(serializer='pickle', persistent=True), **opts)
  22. return AMQPBackend(self.app, **opts)
  23. def test_destination_for(self):
  24. b = self.create_backend()
  25. request = Mock()
  26. assert b.destination_for('id', request) == (
  27. b.rkey('id'), request.correlation_id,
  28. )
  29. def test_store_result__no_routing_key(self):
  30. b = self.create_backend()
  31. b.destination_for = Mock()
  32. b.destination_for.return_value = None, None
  33. b.store_result('id', None, states.SUCCESS)
  34. def test_mark_as_done(self):
  35. tb1 = self.create_backend(max_cached_results=1)
  36. tb2 = self.create_backend(max_cached_results=1)
  37. tid = uuid()
  38. tb1.mark_as_done(tid, 42)
  39. assert tb2.get_state(tid) == states.SUCCESS
  40. assert tb2.get_result(tid) == 42
  41. assert tb2._cache.get(tid)
  42. assert tb2.get_result(tid), 42
  43. @pytest.mark.usefixtures('depends_on_current_app')
  44. def test_pickleable(self):
  45. assert loads(dumps(self.create_backend()))
  46. def test_revive(self):
  47. tb = self.create_backend()
  48. tb.revive(None)
  49. def test_is_pickled(self):
  50. tb1 = self.create_backend()
  51. tb2 = self.create_backend()
  52. tid2 = uuid()
  53. result = {'foo': 'baz', 'bar': SomeClass(12345)}
  54. tb1.mark_as_done(tid2, result)
  55. # is serialized properly.
  56. rindb = tb2.get_result(tid2)
  57. assert rindb.get('foo') == 'baz'
  58. assert rindb.get('bar').data == 12345
  59. def test_mark_as_failure(self):
  60. tb1 = self.create_backend()
  61. tb2 = self.create_backend()
  62. tid3 = uuid()
  63. try:
  64. raise KeyError('foo')
  65. except KeyError as exception:
  66. einfo = ExceptionInfo()
  67. tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
  68. assert tb2.get_state(tid3) == states.FAILURE
  69. assert isinstance(tb2.get_result(tid3), KeyError)
  70. assert tb2.get_traceback(tid3) == einfo.traceback
  71. def test_repair_uuid(self):
  72. from celery.backends.amqp import repair_uuid
  73. for i in range(10):
  74. tid = uuid()
  75. assert repair_uuid(tid.replace('-', '')) == tid
  76. def test_expires_is_int(self):
  77. b = self.create_backend(expires=48)
  78. q = b._create_binding('x1y2z3')
  79. assert q.expires == 48
  80. def test_expires_is_float(self):
  81. b = self.create_backend(expires=48.3)
  82. q = b._create_binding('x1y2z3')
  83. assert q.expires == 48.3
  84. def test_expires_is_timedelta(self):
  85. b = self.create_backend(expires=timedelta(minutes=1))
  86. q = b._create_binding('x1y2z3')
  87. assert q.expires == 60
  88. @mock.sleepdeprived()
  89. def test_store_result_retries(self):
  90. iterations = [0]
  91. stop_raising_at = [5]
  92. def publish(*args, **kwargs):
  93. if iterations[0] > stop_raising_at[0]:
  94. return
  95. iterations[0] += 1
  96. raise KeyError('foo')
  97. backend = AMQPBackend(self.app)
  98. from celery.app.amqp import Producer
  99. prod, Producer.publish = Producer.publish, publish
  100. try:
  101. with pytest.raises(KeyError):
  102. backend.retry_policy['max_retries'] = None
  103. backend.store_result('foo', 'bar', 'STARTED')
  104. with pytest.raises(KeyError):
  105. backend.retry_policy['max_retries'] = 10
  106. backend.store_result('foo', 'bar', 'STARTED')
  107. finally:
  108. Producer.publish = prod
  109. def test_poll_no_messages(self):
  110. b = self.create_backend()
  111. assert b.get_task_meta(uuid())['status'] == states.PENDING
  112. @contextmanager
  113. def _result_context(self):
  114. results = Queue()
  115. class Message(object):
  116. acked = 0
  117. requeued = 0
  118. def __init__(self, **merge):
  119. self.payload = dict({'status': states.STARTED,
  120. 'result': None}, **merge)
  121. self.properties = {'correlation_id': merge.get('task_id')}
  122. self.body = pickle.dumps(self.payload)
  123. self.content_type = 'application/x-python-serialize'
  124. self.content_encoding = 'binary'
  125. def ack(self, *args, **kwargs):
  126. self.acked += 1
  127. def requeue(self, *args, **kwargs):
  128. self.requeued += 1
  129. class MockBinding(object):
  130. def __init__(self, *args, **kwargs):
  131. self.channel = Mock()
  132. def __call__(self, *args, **kwargs):
  133. return self
  134. def declare(self):
  135. pass
  136. def get(self, no_ack=False, accept=None):
  137. try:
  138. m = results.get(block=False)
  139. if m:
  140. m.accept = accept
  141. return m
  142. except Empty:
  143. pass
  144. def is_bound(self):
  145. return True
  146. class MockBackend(AMQPBackend):
  147. Queue = MockBinding
  148. backend = MockBackend(self.app, max_cached_results=100)
  149. backend._republish = Mock()
  150. yield results, backend, Message
  151. def test_backlog_limit_exceeded(self):
  152. with self._result_context() as (results, backend, Message):
  153. for i in range(1001):
  154. results.put(Message(task_id='id', status=states.RECEIVED))
  155. with pytest.raises(backend.BacklogLimitExceeded):
  156. backend.get_task_meta('id')
  157. def test_poll_result(self):
  158. with self._result_context() as (results, backend, Message):
  159. tid = uuid()
  160. # FFWD's to the latest state.
  161. state_messages = [
  162. Message(task_id=tid, status=states.RECEIVED, seq=1),
  163. Message(task_id=tid, status=states.STARTED, seq=2),
  164. Message(task_id=tid, status=states.FAILURE, seq=3),
  165. ]
  166. for state_message in state_messages:
  167. results.put(state_message)
  168. r1 = backend.get_task_meta(tid)
  169. # FFWDs to the last state.
  170. assert r1['status'] == states.FAILURE
  171. assert r1['seq'] == 3
  172. # Caches last known state.
  173. tid = uuid()
  174. results.put(Message(task_id=tid))
  175. backend.get_task_meta(tid)
  176. assert tid, backend._cache in 'Caches last known state'
  177. assert state_messages[-1].requeued
  178. # Returns cache if no new states.
  179. results.queue.clear()
  180. assert not results.qsize()
  181. backend._cache[tid] = 'hello'
  182. # returns cache if no new states.
  183. assert backend.get_task_meta(tid) == 'hello'
  184. def test_drain_events_decodes_exceptions_in_meta(self):
  185. tid = uuid()
  186. b = self.create_backend(serializer='json')
  187. b.store_result(tid, RuntimeError('aap'), states.FAILURE)
  188. result = AsyncResult(tid, backend=b)
  189. with pytest.raises(Exception) as excinfo:
  190. result.get()
  191. assert excinfo.value.__class__.__name__ == 'RuntimeError'
  192. assert str(excinfo.value) == 'aap'
  193. def test_no_expires(self):
  194. b = self.create_backend(expires=None)
  195. app = self.app
  196. app.conf.result_expires = None
  197. b = self.create_backend(expires=None)
  198. q = b._create_binding('foo')
  199. assert q.expires is None
  200. def test_process_cleanup(self):
  201. self.create_backend().process_cleanup()
  202. def test_reload_task_result(self):
  203. with pytest.raises(NotImplementedError):
  204. self.create_backend().reload_task_result('x')
  205. def test_reload_group_result(self):
  206. with pytest.raises(NotImplementedError):
  207. self.create_backend().reload_group_result('x')
  208. def test_save_group(self):
  209. with pytest.raises(NotImplementedError):
  210. self.create_backend().save_group('x', 'x')
  211. def test_restore_group(self):
  212. with pytest.raises(NotImplementedError):
  213. self.create_backend().restore_group('x')
  214. def test_delete_group(self):
  215. with pytest.raises(NotImplementedError):
  216. self.create_backend().delete_group('x')