# test_amqp.py — unit tests for the celery AMQP result backend (AMQPBackend).
  1. from __future__ import absolute_import
  2. import pickle
  3. import socket
  4. from contextlib import contextmanager
  5. from datetime import timedelta
  6. from pickle import dumps, loads
  7. from billiard.einfo import ExceptionInfo
  8. from celery import states
  9. from celery.backends.amqp import AMQPBackend
  10. from celery.exceptions import TimeoutError
  11. from celery.five import Empty, Queue, range
  12. from celery.result import AsyncResult
  13. from celery.utils import uuid
  14. from celery.tests.case import (
  15. AppCase, Mock, depends_on_current_app, patch, sleepdeprived,
  16. )
  17. class SomeClass(object):
  18. def __init__(self, data):
  19. self.data = data
  20. class test_AMQPBackend(AppCase):
  21. def create_backend(self, **opts):
  22. opts = dict(dict(serializer='pickle', persistent=True), **opts)
  23. return AMQPBackend(self.app, **opts)
  24. def test_mark_as_done(self):
  25. tb1 = self.create_backend(max_cached_results=1)
  26. tb2 = self.create_backend(max_cached_results=1)
  27. tid = uuid()
  28. tb1.mark_as_done(tid, 42)
  29. self.assertEqual(tb2.get_status(tid), states.SUCCESS)
  30. self.assertEqual(tb2.get_result(tid), 42)
  31. self.assertTrue(tb2._cache.get(tid))
  32. self.assertTrue(tb2.get_result(tid), 42)
  33. @depends_on_current_app
  34. def test_pickleable(self):
  35. self.assertTrue(loads(dumps(self.create_backend())))
  36. def test_revive(self):
  37. tb = self.create_backend()
  38. tb.revive(None)
  39. def test_is_pickled(self):
  40. tb1 = self.create_backend()
  41. tb2 = self.create_backend()
  42. tid2 = uuid()
  43. result = {'foo': 'baz', 'bar': SomeClass(12345)}
  44. tb1.mark_as_done(tid2, result)
  45. # is serialized properly.
  46. rindb = tb2.get_result(tid2)
  47. self.assertEqual(rindb.get('foo'), 'baz')
  48. self.assertEqual(rindb.get('bar').data, 12345)
    def test_mark_as_failure(self):
        """mark_as_failure records FAILURE plus the exception and traceback,
        all retrievable from a second backend instance."""
        tb1 = self.create_backend()
        tb2 = self.create_backend()

        tid3 = uuid()
        try:
            raise KeyError('foo')
        except KeyError as exception:
            # ExceptionInfo() must be built while the exception is being
            # handled so it captures the live traceback.
            einfo = ExceptionInfo()
            tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
            self.assertEqual(tb2.get_status(tid3), states.FAILURE)
            self.assertIsInstance(tb2.get_result(tid3), KeyError)
            self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
  61. def test_repair_uuid(self):
  62. from celery.backends.amqp import repair_uuid
  63. for i in range(10):
  64. tid = uuid()
  65. self.assertEqual(repair_uuid(tid.replace('-', '')), tid)
  66. def test_expires_is_int(self):
  67. b = self.create_backend(expires=48)
  68. self.assertEqual(b.queue_arguments.get('x-expires'), 48 * 1000.0)
  69. def test_expires_is_float(self):
  70. b = self.create_backend(expires=48.3)
  71. self.assertEqual(b.queue_arguments.get('x-expires'), 48.3 * 1000.0)
  72. def test_expires_is_timedelta(self):
  73. b = self.create_backend(expires=timedelta(minutes=1))
  74. self.assertEqual(b.queue_arguments.get('x-expires'), 60 * 1000.0)
    @sleepdeprived()
    def test_store_result_retries(self):
        """store_result retries failing publishes per retry_policy and
        re-raises the underlying error once retries are exhausted."""
        # Mutable cells so the closure below can count attempts.
        iterations = [0]
        stop_raising_at = [5]

        def publish(*args, **kwargs):
            # Fail the first few publish attempts, then succeed silently.
            if iterations[0] > stop_raising_at[0]:
                return
            iterations[0] += 1
            raise KeyError('foo')

        backend = AMQPBackend(self.app)
        from celery.app.amqp import Producer
        # Monkey-patch Producer.publish globally; restored in finally.
        prod, Producer.publish = Producer.publish, publish
        try:
            with self.assertRaises(KeyError):
                backend.retry_policy['max_retries'] = None
                backend.store_result('foo', 'bar', 'STARTED')

            with self.assertRaises(KeyError):
                backend.retry_policy['max_retries'] = 10
                backend.store_result('foo', 'bar', 'STARTED')
        finally:
            Producer.publish = prod
  96. def assertState(self, retval, state):
  97. self.assertEqual(retval['status'], state)
  98. def test_poll_no_messages(self):
  99. b = self.create_backend()
  100. self.assertState(b.get_task_meta(uuid()), states.PENDING)
    @contextmanager
    def _result_context(self):
        """Yield ``(results, backend, Message)`` where *backend* is an
        AMQPBackend whose queue binding is mocked to pop messages from the
        in-memory *results* Queue, and *Message* builds fake result messages.
        """
        results = Queue()

        class Message(object):
            # Class-level counters shadowed per-instance on first ack/requeue.
            acked = 0
            requeued = 0

            def __init__(self, **merge):
                # Default payload is a STARTED state; **merge overrides it.
                self.payload = dict({'status': states.STARTED,
                                     'result': None}, **merge)
                self.body = pickle.dumps(self.payload)
                self.content_type = 'application/x-python-serialize'
                self.content_encoding = 'binary'

            def ack(self, *args, **kwargs):
                self.acked += 1

            def requeue(self, *args, **kwargs):
                self.requeued += 1

        class MockBinding(object):

            def __init__(self, *args, **kwargs):
                self.channel = Mock()

            def __call__(self, *args, **kwargs):
                return self

            def declare(self):
                pass

            def get(self, no_ack=False, accept=None):
                # Pop the next queued fake message; None when empty,
                # mirroring kombu's non-blocking Queue.get.
                try:
                    m = results.get(block=False)
                    if m:
                        m.accept = accept
                    return m
                except Empty:
                    pass

            def is_bound(self):
                return True

        class MockBackend(AMQPBackend):
            # Substitute the mocked binding for the real queue class.
            Queue = MockBinding

        backend = MockBackend(self.app, max_cached_results=100)
        backend._republish = Mock()

        yield results, backend, Message
  139. def test_backlog_limit_exceeded(self):
  140. with self._result_context() as (results, backend, Message):
  141. for i in range(1001):
  142. results.put(Message(task_id='id', status=states.RECEIVED))
  143. with self.assertRaises(backend.BacklogLimitExceeded):
  144. backend.get_task_meta('id')
    def test_poll_result(self):
        """get_task_meta fast-forwards to the newest state message, caches
        the last known state, and serves from cache when no new messages."""
        with self._result_context() as (results, backend, Message):
            tid = uuid()
            # FFWD's to the latest state.
            state_messages = [
                Message(task_id=tid, status=states.RECEIVED, seq=1),
                Message(task_id=tid, status=states.STARTED, seq=2),
                Message(task_id=tid, status=states.FAILURE, seq=3),
            ]
            for state_message in state_messages:
                results.put(state_message)
            r1 = backend.get_task_meta(tid)
            self.assertDictContainsSubset(
                {'status': states.FAILURE, 'seq': 3}, r1,
                'FFWDs to the last state',
            )

            # Caches last known state.
            tid = uuid()
            results.put(Message(task_id=tid))
            backend.get_task_meta(tid)
            self.assertIn(tid, backend._cache, 'Caches last known state')

            # The last message is requeued so other consumers still see it.
            self.assertTrue(state_messages[-1].requeued)

            # Returns cache if no new states.
            results.queue.clear()
            assert not results.qsize()
            backend._cache[tid] = 'hello'
            self.assertEqual(
                backend.get_task_meta(tid), 'hello',
                'Returns cache if no new states',
            )
    def test_wait_for(self):
        """wait_for times out on non-ready states (PENDING/STARTED/RETRY),
        returns ready results, and caches the first ready result."""
        b = self.create_backend()

        tid = uuid()
        with self.assertRaises(TimeoutError):
            b.wait_for(tid, timeout=0.1)
        b.store_result(tid, None, states.STARTED)
        with self.assertRaises(TimeoutError):
            b.wait_for(tid, timeout=0.1)
        b.store_result(tid, None, states.RETRY)
        with self.assertRaises(TimeoutError):
            b.wait_for(tid, timeout=0.1)
        b.store_result(tid, 42, states.SUCCESS)
        self.assertEqual(b.wait_for(tid, timeout=1)['result'], 42)
        b.store_result(tid, 56, states.SUCCESS)
        # The first ready result was cached; cache=False bypasses the cache.
        self.assertEqual(b.wait_for(tid, timeout=1)['result'], 42,
                         'result is cached')
        self.assertEqual(b.wait_for(tid, timeout=1, cache=False)['result'], 56)
        b.store_result(tid, KeyError('foo'), states.FAILURE)
        res = b.wait_for(tid, timeout=1, cache=False)
        self.assertEqual(res['status'], states.FAILURE)
        b.store_result(tid, KeyError('foo'), states.PENDING)
        with self.assertRaises(TimeoutError):
            b.wait_for(tid, timeout=0.01, cache=False)
  198. def test_drain_events_decodes_exceptions_in_meta(self):
  199. tid = uuid()
  200. b = self.create_backend(serializer="json")
  201. b.store_result(tid, RuntimeError("aap"), states.FAILURE)
  202. result = AsyncResult(tid, backend=b)
  203. with self.assertRaises(Exception) as cm:
  204. result.get()
  205. self.assertEqual(cm.exception.__class__.__name__, "RuntimeError")
  206. self.assertEqual(str(cm.exception), "aap")
    def test_drain_events_remaining_timeouts(self):
        """drain_events raises socket.timeout when the connection never
        delivers anything within the timeout."""
        class Connection(object):
            # A connection whose drain_events never produces a message.
            def drain_events(self, timeout=None):
                pass

        b = self.create_backend()
        with self.app.pool.acquire_channel(block=False) as (_, channel):
            binding = b._create_binding(uuid())
            consumer = b.Consumer(channel, binding, no_ack=True)
            with self.assertRaises(socket.timeout):
                b.drain_events(Connection(), consumer, timeout=0.1)
    def test_get_many(self):
        """get_many yields ``(task_id, meta)`` pairs for ready results,
        caches them, and raises socket.timeout for non-ready results."""
        b = self.create_backend(max_cached_results=10)

        tids = []
        for i in range(10):
            tid = uuid()
            b.store_result(tid, i, states.SUCCESS)
            tids.append(tid)

        res = list(b.get_many(tids, timeout=1))
        expected_results = [
            (_tid, {'status': states.SUCCESS,
                    'result': i,
                    'traceback': None,
                    'task_id': _tid,
                    'children': None})
            for i, _tid in enumerate(tids)
        ]
        self.assertEqual(sorted(res), sorted(expected_results))
        # Results from the first retrieval end up in the local cache.
        self.assertDictEqual(b._cache[res[0][0]], res[0][1])
        cached_res = list(b.get_many(tids, timeout=1))
        self.assertEqual(sorted(cached_res), sorted(expected_results))

        # times out when not ready in cache (this shouldn't happen)
        b._cache[res[0][0]]['status'] = states.RETRY
        with self.assertRaises(socket.timeout):
            list(b.get_many(tids, timeout=0.01))

        # times out when result not yet ready
        with self.assertRaises(socket.timeout):
            tids = [uuid()]
            b.store_result(tids[0], i, states.PENDING)
            list(b.get_many(tids, timeout=0.01))
  246. def test_get_many_on_message(self):
  247. b = self.create_backend(max_cached_results=10)
  248. tids = []
  249. for i in range(10):
  250. tid = uuid()
  251. b.store_result(tid, '', states.PENDING)
  252. b.store_result(tid, 'comment_%i_1' % i, states.STARTED)
  253. b.store_result(tid, 'comment_%i_2' % i, states.STARTED)
  254. b.store_result(tid, 'final result %i' % i, states.SUCCESS)
  255. tids.append(tid)
  256. expected_messages = {}
  257. for i, _tid in enumerate(tids):
  258. expected_messages[_tid] = []
  259. expected_messages[_tid].append( (states.PENDING, '') )
  260. expected_messages[_tid].append( (states.STARTED, 'comment_%i_1' % i) )
  261. expected_messages[_tid].append( (states.STARTED, 'comment_%i_2' % i) )
  262. expected_messages[_tid].append( (states.SUCCESS, 'final result %i' % i) )
  263. on_message_results = {}
  264. def on_message(body):
  265. if not body['task_id'] in on_message_results:
  266. on_message_results[body['task_id']] = []
  267. on_message_results[body['task_id']].append( (body['status'], body['result']) )
  268. res = list(b.get_many(tids, timeout=1, on_message=on_message))
  269. self.assertEqual(sorted(on_message_results), sorted(expected_messages))
    def test_get_many_raises_outer_block(self):
        """An error raised while creating the consumer propagates out of
        the get_many generator on first iteration."""
        class Backend(AMQPBackend):
            def Consumer(*args, **kwargs):
                raise KeyError('foo')

        b = Backend(self.app)
        with self.assertRaises(KeyError):
            next(b.get_many(['id1']))
  277. def test_get_many_raises_inner_block(self):
  278. with patch('kombu.connection.Connection.drain_events') as drain:
  279. drain.side_effect = KeyError('foo')
  280. b = AMQPBackend(self.app)
  281. with self.assertRaises(KeyError):
  282. next(b.get_many(['id1']))
    def test_consume_raises_inner_block(self):
        """consume() surfaces the error raised while draining events."""
        with patch('kombu.connection.Connection.drain_events') as drain:

            def se(*args, **kwargs):
                # First call fails with KeyError after arming the mock so
                # that any subsequent drain raises ValueError instead.
                drain.side_effect = ValueError()
                raise KeyError('foo')
            drain.side_effect = se
            b = AMQPBackend(self.app)
            with self.assertRaises(ValueError):
                next(b.consume('id1'))
    def test_no_expires(self):
        """With expires disabled the queue gets no x-expires argument."""
        # First creation: expires=None while the app conf still has its
        # default CELERY_TASK_RESULT_EXPIRES (must not raise).
        b = self.create_backend(expires=None)
        app = self.app
        app.conf.CELERY_TASK_RESULT_EXPIRES = None
        # Second creation: both sources disabled -> no x-expires key at all.
        b = self.create_backend(expires=None)
        with self.assertRaises(KeyError):
            b.queue_arguments['x-expires']
  299. def test_process_cleanup(self):
  300. self.create_backend().process_cleanup()
  301. def test_reload_task_result(self):
  302. with self.assertRaises(NotImplementedError):
  303. self.create_backend().reload_task_result('x')
  304. def test_reload_group_result(self):
  305. with self.assertRaises(NotImplementedError):
  306. self.create_backend().reload_group_result('x')
  307. def test_save_group(self):
  308. with self.assertRaises(NotImplementedError):
  309. self.create_backend().save_group('x', 'x')
  310. def test_restore_group(self):
  311. with self.assertRaises(NotImplementedError):
  312. self.create_backend().restore_group('x')
  313. def test_delete_group(self):
  314. with self.assertRaises(NotImplementedError):
  315. self.create_backend().delete_group('x')