test_amqp.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. from __future__ import absolute_import
  2. import socket
  3. from datetime import timedelta
  4. from Queue import Empty, Queue
  5. from celery import current_app
  6. from celery import states
  7. from celery.app import app_or_default
  8. from celery.backends.amqp import AMQPBackend
  9. from celery.datastructures import ExceptionInfo
  10. from celery.exceptions import TimeoutError
  11. from celery.utils import uuid
  12. from celery.tests.utils import AppCase, sleepdeprived
  13. class SomeClass(object):
  14. def __init__(self, data):
  15. self.data = data
  16. class test_AMQPBackend(AppCase):
  17. def create_backend(self, **opts):
  18. opts = dict(dict(serializer='pickle', persistent=False), **opts)
  19. return AMQPBackend(**opts)
  20. def test_mark_as_done(self):
  21. tb1 = self.create_backend()
  22. tb2 = self.create_backend()
  23. tid = uuid()
  24. tb1.mark_as_done(tid, 42)
  25. self.assertEqual(tb2.get_status(tid), states.SUCCESS)
  26. self.assertEqual(tb2.get_result(tid), 42)
  27. self.assertTrue(tb2._cache.get(tid))
  28. self.assertTrue(tb2.get_result(tid), 42)
  29. def test_revive(self):
  30. tb = self.create_backend()
  31. tb.revive(None)
  32. def test_is_pickled(self):
  33. tb1 = self.create_backend()
  34. tb2 = self.create_backend()
  35. tid2 = uuid()
  36. result = {'foo': 'baz', 'bar': SomeClass(12345)}
  37. tb1.mark_as_done(tid2, result)
  38. # is serialized properly.
  39. rindb = tb2.get_result(tid2)
  40. self.assertEqual(rindb.get('foo'), 'baz')
  41. self.assertEqual(rindb.get('bar').data, 12345)
  42. def test_mark_as_failure(self):
  43. tb1 = self.create_backend()
  44. tb2 = self.create_backend()
  45. tid3 = uuid()
  46. try:
  47. raise KeyError('foo')
  48. except KeyError as exception:
  49. einfo = ExceptionInfo()
  50. tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
  51. self.assertEqual(tb2.get_status(tid3), states.FAILURE)
  52. self.assertIsInstance(tb2.get_result(tid3), KeyError)
  53. self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
  54. def test_repair_uuid(self):
  55. from celery.backends.amqp import repair_uuid
  56. for i in range(10):
  57. tid = uuid()
  58. self.assertEqual(repair_uuid(tid.replace('-', '')), tid)
  59. def test_expires_is_int(self):
  60. b = self.create_backend(expires=48)
  61. self.assertEqual(b.queue_arguments.get('x-expires'), 48 * 1000.0)
  62. def test_expires_is_float(self):
  63. b = self.create_backend(expires=48.3)
  64. self.assertEqual(b.queue_arguments.get('x-expires'), 48.3 * 1000.0)
  65. def test_expires_is_timedelta(self):
  66. b = self.create_backend(expires=timedelta(minutes=1))
  67. self.assertEqual(b.queue_arguments.get('x-expires'), 60 * 1000.0)
  68. @sleepdeprived()
  69. def test_store_result_retries(self):
  70. iterations = [0]
  71. stop_raising_at = [5]
  72. def publish(*args, **kwargs):
  73. if iterations[0] > stop_raising_at[0]:
  74. return
  75. iterations[0] += 1
  76. raise KeyError('foo')
  77. backend = AMQPBackend()
  78. from celery.app.amqp import TaskProducer
  79. prod, TaskProducer.publish = TaskProducer.publish, publish
  80. try:
  81. with self.assertRaises(KeyError):
  82. backend.retry_policy['max_retries'] = None
  83. backend.store_result('foo', 'bar', 'STARTED')
  84. with self.assertRaises(KeyError):
  85. backend.retry_policy['max_retries'] = 10
  86. backend.store_result('foo', 'bar', 'STARTED')
  87. finally:
  88. TaskProducer.publish = prod
  89. def assertState(self, retval, state):
  90. self.assertEqual(retval['status'], state)
  91. def test_poll_no_messages(self):
  92. b = self.create_backend()
  93. self.assertState(b.get_task_meta(uuid()), states.PENDING)
  94. def test_poll_result(self):
  95. results = Queue()
  96. class Message(object):
  97. def __init__(self, **merge):
  98. self.payload = dict({'status': states.STARTED,
  99. 'result': None}, **merge)
  100. class MockBinding(object):
  101. def __init__(self, *args, **kwargs):
  102. pass
  103. def __call__(self, *args, **kwargs):
  104. return self
  105. def declare(self):
  106. pass
  107. def get(self, no_ack=False):
  108. try:
  109. return results.get(block=False)
  110. except Empty:
  111. pass
  112. class MockBackend(AMQPBackend):
  113. Queue = MockBinding
  114. backend = MockBackend()
  115. # FFWD's to the latest state.
  116. results.put(Message(status=states.RECEIVED, seq=1))
  117. results.put(Message(status=states.STARTED, seq=2))
  118. results.put(Message(status=states.FAILURE, seq=3))
  119. r1 = backend.get_task_meta(uuid())
  120. self.assertDictContainsSubset({'status': states.FAILURE,
  121. 'seq': 3}, r1,
  122. 'FFWDs to the last state')
  123. # Caches last known state.
  124. results.put(Message())
  125. tid = uuid()
  126. backend.get_task_meta(tid)
  127. self.assertIn(tid, backend._cache, 'Caches last known state')
  128. # Returns cache if no new states.
  129. results.queue.clear()
  130. assert not results.qsize()
  131. backend._cache[tid] = 'hello'
  132. self.assertEqual(backend.get_task_meta(tid), 'hello',
  133. 'Returns cache if no new states')
  134. def test_wait_for(self):
  135. b = self.create_backend()
  136. tid = uuid()
  137. with self.assertRaises(TimeoutError):
  138. b.wait_for(tid, timeout=0.1)
  139. b.store_result(tid, None, states.STARTED)
  140. with self.assertRaises(TimeoutError):
  141. b.wait_for(tid, timeout=0.1)
  142. b.store_result(tid, None, states.RETRY)
  143. with self.assertRaises(TimeoutError):
  144. b.wait_for(tid, timeout=0.1)
  145. b.store_result(tid, 42, states.SUCCESS)
  146. self.assertEqual(b.wait_for(tid, timeout=1), 42)
  147. b.store_result(tid, 56, states.SUCCESS)
  148. self.assertEqual(b.wait_for(tid, timeout=1), 42,
  149. 'result is cached')
  150. self.assertEqual(b.wait_for(tid, timeout=1, cache=False), 56)
  151. b.store_result(tid, KeyError('foo'), states.FAILURE)
  152. with self.assertRaises(KeyError):
  153. b.wait_for(tid, timeout=1, cache=False)
  154. def test_drain_events_remaining_timeouts(self):
  155. class Connection(object):
  156. def drain_events(self, timeout=None):
  157. pass
  158. b = self.create_backend()
  159. with current_app.pool.acquire_channel(block=False) as (_, channel):
  160. binding = b._create_binding(uuid())
  161. consumer = b.Consumer(channel, binding, no_ack=True)
  162. with self.assertRaises(socket.timeout):
  163. b.drain_events(Connection(), consumer, timeout=0.1)
  164. def test_get_many(self):
  165. b = self.create_backend()
  166. tids = []
  167. for i in xrange(10):
  168. tid = uuid()
  169. b.store_result(tid, i, states.SUCCESS)
  170. tids.append(tid)
  171. res = list(b.get_many(tids, timeout=1))
  172. expected_results = [(tid, {'status': states.SUCCESS,
  173. 'result': i,
  174. 'traceback': None,
  175. 'task_id': tid,
  176. 'children': None})
  177. for i, tid in enumerate(tids)]
  178. self.assertEqual(sorted(res), sorted(expected_results))
  179. self.assertDictEqual(b._cache[res[0][0]], res[0][1])
  180. cached_res = list(b.get_many(tids, timeout=1))
  181. self.assertEqual(sorted(cached_res), sorted(expected_results))
  182. b._cache[res[0][0]]['status'] = states.RETRY
  183. with self.assertRaises(socket.timeout):
  184. list(b.get_many(tids, timeout=0.01))
  185. def test_test_get_many_raises_outer_block(self):
  186. class Backend(AMQPBackend):
  187. def Consumer(*args, **kwargs):
  188. raise KeyError('foo')
  189. b = Backend()
  190. with self.assertRaises(KeyError):
  191. next(b.get_many(['id1']))
  192. def test_test_get_many_raises_inner_block(self):
  193. class Backend(AMQPBackend):
  194. def drain_events(self, *args, **kwargs):
  195. raise KeyError('foo')
  196. b = Backend()
  197. with self.assertRaises(KeyError):
  198. next(b.get_many(['id1']))
  199. def test_no_expires(self):
  200. b = self.create_backend(expires=None)
  201. app = app_or_default()
  202. prev = app.conf.CELERY_TASK_RESULT_EXPIRES
  203. app.conf.CELERY_TASK_RESULT_EXPIRES = None
  204. try:
  205. b = self.create_backend(expires=None)
  206. with self.assertRaises(KeyError):
  207. b.queue_arguments['x-expires']
  208. finally:
  209. app.conf.CELERY_TASK_RESULT_EXPIRES = prev
  210. def test_process_cleanup(self):
  211. self.create_backend().process_cleanup()
  212. def test_reload_task_result(self):
  213. with self.assertRaises(NotImplementedError):
  214. self.create_backend().reload_task_result('x')
  215. def test_reload_group_result(self):
  216. with self.assertRaises(NotImplementedError):
  217. self.create_backend().reload_group_result('x')
  218. def test_save_group(self):
  219. with self.assertRaises(NotImplementedError):
  220. self.create_backend().save_group('x', 'x')
  221. def test_restore_group(self):
  222. with self.assertRaises(NotImplementedError):
  223. self.create_backend().restore_group('x')
  224. def test_delete_group(self):
  225. with self.assertRaises(NotImplementedError):
  226. self.create_backend().delete_group('x')