test_amqp.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import socket
  2. import sys
  3. from datetime import timedelta
  4. from celery import states
  5. from celery.app import app_or_default
  6. from celery.backends.amqp import AMQPBackend
  7. from celery.datastructures import ExceptionInfo
  8. from celery.exceptions import TimeoutError
  9. from celery.utils import gen_unique_id
  10. from celery.tests.utils import unittest
  11. from celery.tests.utils import sleepdeprived
  12. class SomeClass(object):
  13. def __init__(self, data):
  14. self.data = data
  15. class test_AMQPBackend(unittest.TestCase):
  16. def create_backend(self, **opts):
  17. opts = dict(dict(serializer="pickle", persistent=False), **opts)
  18. return AMQPBackend(**opts)
  19. def test_mark_as_done(self):
  20. tb1 = self.create_backend()
  21. tb2 = self.create_backend()
  22. tid = gen_unique_id()
  23. tb1.mark_as_done(tid, 42)
  24. self.assertEqual(tb2.get_status(tid), states.SUCCESS)
  25. self.assertEqual(tb2.get_result(tid), 42)
  26. self.assertTrue(tb2._cache.get(tid))
  27. self.assertTrue(tb2.get_result(tid), 42)
  28. def test_is_pickled(self):
  29. tb1 = self.create_backend()
  30. tb2 = self.create_backend()
  31. tid2 = gen_unique_id()
  32. result = {"foo": "baz", "bar": SomeClass(12345)}
  33. tb1.mark_as_done(tid2, result)
  34. # is serialized properly.
  35. rindb = tb2.get_result(tid2)
  36. self.assertEqual(rindb.get("foo"), "baz")
  37. self.assertEqual(rindb.get("bar").data, 12345)
  38. def test_mark_as_failure(self):
  39. tb1 = self.create_backend()
  40. tb2 = self.create_backend()
  41. tid3 = gen_unique_id()
  42. try:
  43. raise KeyError("foo")
  44. except KeyError, exception:
  45. einfo = ExceptionInfo(sys.exc_info())
  46. tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
  47. self.assertEqual(tb2.get_status(tid3), states.FAILURE)
  48. self.assertIsInstance(tb2.get_result(tid3), KeyError)
  49. self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
  50. def test_repair_uuid(self):
  51. from celery.backends.amqp import repair_uuid
  52. for i in range(10):
  53. uuid = gen_unique_id()
  54. self.assertEqual(repair_uuid(uuid.replace("-", "")), uuid)
  55. def test_expires_defaults_to_config(self):
  56. app = app_or_default()
  57. prev = app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES
  58. app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = 10
  59. try:
  60. b = self.create_backend(expires=None)
  61. self.assertEqual(b.queue_arguments.get("x-expires"), 10 * 1000.0)
  62. finally:
  63. app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
  64. def test_expires_is_int(self):
  65. b = self.create_backend(expires=48)
  66. self.assertEqual(b.queue_arguments.get("x-expires"), 48 * 1000.0)
  67. def test_expires_is_timedelta(self):
  68. b = self.create_backend(expires=timedelta(minutes=1))
  69. self.assertEqual(b.queue_arguments.get("x-expires"), 60 * 1000.0)
  70. @sleepdeprived()
  71. def test_store_result_retries(self):
  72. class _Producer(object):
  73. iterations = 0
  74. stop_raising_at = 5
  75. def __init__(self, *args, **kwargs):
  76. pass
  77. def publish(self, msg, *args, **kwargs):
  78. if self.iterations > self.stop_raising_at:
  79. return
  80. raise KeyError("foo")
  81. class Backend(AMQPBackend):
  82. Producer = _Producer
  83. backend = Backend()
  84. self.assertRaises(KeyError, backend.store_result,
  85. "foo", "bar", "STARTED", max_retries=None)
  86. print(backend.store_result)
  87. self.assertRaises(KeyError, backend.store_result,
  88. "foo", "bar", "STARTED", max_retries=10)
  89. def assertState(self, retval, state):
  90. self.assertEqual(retval["status"], state)
  91. def test_poll_no_messages(self):
  92. b = self.create_backend()
  93. self.assertState(b.poll(gen_unique_id()), states.PENDING)
  94. def test_poll_result(self):
  95. class MockBinding(object):
  96. get_returns = [True]
  97. def __init__(self, *args, **kwargs):
  98. pass
  99. def __call__(self, *args, **kwargs):
  100. return self
  101. def declare(self):
  102. pass
  103. def get(self, no_ack=False):
  104. if self.get_returns[0]:
  105. class Object(object):
  106. payload = {"status": "STARTED",
  107. "result": None}
  108. return Object()
  109. class MockBackend(AMQPBackend):
  110. Queue = MockBinding
  111. backend = MockBackend()
  112. backend.poll(gen_unique_id())
  113. uuid = gen_unique_id()
  114. backend.poll(uuid)
  115. self.assertIn(uuid, backend._cache)
  116. MockBinding.get_returns[0] = False
  117. backend._cache[uuid] = "hello"
  118. self.assertEqual(backend.poll(uuid), "hello")
  119. def test_wait_for(self):
  120. b = self.create_backend()
  121. uuid = gen_unique_id()
  122. self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
  123. b.store_result(uuid, None, states.STARTED)
  124. self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
  125. b.store_result(uuid, None, states.RETRY)
  126. self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
  127. b.store_result(uuid, 42, states.SUCCESS)
  128. self.assertEqual(b.wait_for(uuid, timeout=1), 42)
  129. b.store_result(uuid, 56, states.SUCCESS)
  130. self.assertEqual(b.wait_for(uuid, timeout=1), 42,
  131. "result is cached")
  132. self.assertEqual(b.wait_for(uuid, timeout=1, cache=False), 56)
  133. b.store_result(uuid, KeyError("foo"), states.FAILURE)
  134. self.assertRaises(KeyError, b.wait_for, uuid, timeout=1, cache=False)
  135. def test_drain_events_remaining_timeouts(self):
  136. class Connection(object):
  137. def drain_events(self, timeout=None):
  138. pass
  139. b = self.create_backend()
  140. conn = b.pool.acquire(block=False)
  141. channel = conn.channel()
  142. try:
  143. binding = b._create_binding(gen_unique_id())
  144. consumer = b._create_consumer(binding, channel)
  145. self.assertRaises(socket.timeout, b.drain_events,
  146. Connection(), consumer, timeout=0.1)
  147. finally:
  148. channel.close()
  149. conn.release()
  150. def test_get_many(self):
  151. b = self.create_backend()
  152. uuids = []
  153. for i in xrange(10):
  154. uuid = gen_unique_id()
  155. b.store_result(uuid, i, states.SUCCESS)
  156. uuids.append(uuid)
  157. res = list(b.get_many(uuids, timeout=1))
  158. expected_results = [(uuid, {"status": states.SUCCESS,
  159. "result": i,
  160. "traceback": None,
  161. "task_id": uuid})
  162. for i, uuid in enumerate(uuids)]
  163. self.assertItemsEqual(res, expected_results)
  164. self.assertDictEqual(b._cache[res[0][0]], res[0][1])
  165. cached_res = list(b.get_many(uuids, timeout=1))
  166. self.assertItemsEqual(cached_res, expected_results)
  167. b._cache[res[0][0]]["status"] = states.RETRY
  168. self.assertRaises(socket.timeout, list,
  169. b.get_many(uuids, timeout=0.01))
  170. def test_test_get_many_raises_outer_block(self):
  171. class Backend(AMQPBackend):
  172. def _create_consumer(self, *args, **kwargs):
  173. raise KeyError("foo")
  174. b = Backend()
  175. self.assertRaises(KeyError, b.get_many(["id1"]).next)
  176. def test_test_get_many_raises_inner_block(self):
  177. class Backend(AMQPBackend):
  178. def drain_events(self, *args, **kwargs):
  179. raise KeyError("foo")
  180. b = Backend()
  181. self.assertRaises(KeyError, b.get_many(["id1"]).next)
  182. def test_no_expires(self):
  183. b = self.create_backend(expires=None)
  184. app = app_or_default()
  185. prev = app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES
  186. app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = None
  187. try:
  188. b = self.create_backend(expires=None)
  189. self.assertRaises(KeyError, b.queue_arguments.__getitem__,
  190. "x-expires")
  191. finally:
  192. app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
  193. def test_process_cleanup(self):
  194. self.create_backend().process_cleanup()
  195. def test_reload_task_result(self):
  196. self.assertRaises(NotImplementedError,
  197. self.create_backend().reload_task_result, "x")
  198. def test_reload_taskset_result(self):
  199. self.assertRaises(NotImplementedError,
  200. self.create_backend().reload_taskset_result, "x")
  201. def test_save_taskset(self):
  202. self.assertRaises(NotImplementedError,
  203. self.create_backend().save_taskset, "x", "x")
  204. def test_restore_taskset(self):
  205. self.assertRaises(NotImplementedError,
  206. self.create_backend().restore_taskset, "x")