test_result.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. from celery import states
  2. from celery.app import app_or_default
  3. from celery.utils import uuid
  4. from celery.utils.serialization import pickle
  5. from celery.result import AsyncResult, EagerResult, TaskSetResult, ResultSet
  6. from celery.exceptions import TimeoutError
  7. from celery.task.base import Task
  8. from celery.tests.utils import unittest
  9. from celery.tests.utils import skip_if_quick
  def mock_task(name, status, result):
      """Describe a fake task as a dict with a freshly generated id.

      The returned dict has the keys ``id``, ``name``, ``status`` and
      ``result``; ``id`` comes from ``uuid()``, the rest are the arguments.
      """
      return {
          "id": uuid(),
          "name": name,
          "status": status,
          "result": result,
      }
  def save_result(task):
      """Record the outcome of *task* in the default app's backend.

      *task* is a dict as produced by ``mock_task``.  A SUCCESS status is
      stored with ``mark_as_done``; RETRY uses ``mark_as_retry``; any other
      status uses ``mark_as_failure``.  The last two also receive a fixed
      placeholder traceback string.
      """
      app = app_or_default()
      status = task["status"]

      if status == states.SUCCESS:
          app.backend.mark_as_done(task["id"], task["result"])
          return

      # Both non-success recorders take the same arguments, so pick one
      # and call it once.
      if status == states.RETRY:
          record = app.backend.mark_as_retry
      else:
          record = app.backend.mark_as_failure
      record(task["id"], task["result"], traceback="Some traceback")
  def make_mock_taskset(size=10):
      """Create, store and wrap *size* successful mock tasks.

      Task ``i`` is named ``"ts<i>"`` and has ``i`` as its result.  Every
      task is stored through ``save_result`` before being wrapped.

      Returns a list with one ``AsyncResult`` per task, in creation order.
      """
      tasks = [mock_task("ts%d" % i, states.SUCCESS, i) for i in xrange(size)]
      # save_result is called only for its side effect, so use a plain
      # loop rather than building a throwaway list of None values.
      for task in tasks:
          save_result(task)
      return [AsyncResult(task["id"]) for task in tasks]
  class TestAsyncResult(unittest.TestCase):
      """Exercise AsyncResult against four results stored by save_result.

      Fixtures created in setUp: task1 and task2 are SUCCESS, task3 is
      FAILURE and task4 is RETRY.
      """

      def setUp(self):
          self.task1 = mock_task("task1", states.SUCCESS, "the")
          self.task2 = mock_task("task2", states.SUCCESS, "quick")
          self.task3 = mock_task("task3", states.FAILURE, KeyError("brown"))
          # Named "task4" to match its attribute; it used to reuse "task3".
          self.task4 = mock_task("task4", states.RETRY, KeyError("red"))
          for task in (self.task1, self.task2, self.task3, self.task4):
              save_result(task)

      def test_reduce(self):
          # A pickle round trip must keep the task id and the task name.
          a1 = AsyncResult("uuid", task_name="celery.ping")
          restored = pickle.loads(pickle.dumps(a1))
          self.assertEqual(restored.task_id, "uuid")
          self.assertEqual(restored.task_name, "celery.ping")

          a2 = AsyncResult("uuid")
          self.assertEqual(pickle.loads(pickle.dumps(a2)).task_id, "uuid")

      def test_successful(self):
          # Only the SUCCESS fixture is expected to report successful().
          ok_res = AsyncResult(self.task1["id"])
          nok_res = AsyncResult(self.task3["id"])
          nok_res2 = AsyncResult(self.task4["id"])

          self.assertTrue(ok_res.successful())
          self.assertFalse(nok_res.successful())
          self.assertFalse(nok_res2.successful())

          # An id that was never stored.
          pending_res = AsyncResult(uuid())
          self.assertFalse(pending_res.successful())

      def test_str(self):
          # str() of a result is expected to be its task id.
          ok_res = AsyncResult(self.task1["id"])
          ok2_res = AsyncResult(self.task2["id"])
          nok_res = AsyncResult(self.task3["id"])
          self.assertEqual(str(ok_res), self.task1["id"])
          self.assertEqual(str(ok2_res), self.task2["id"])
          self.assertEqual(str(nok_res), self.task3["id"])

          pending_id = uuid()
          pending_res = AsyncResult(pending_id)
          self.assertEqual(str(pending_res), pending_id)

      def test_repr(self):
          ok_res = AsyncResult(self.task1["id"])
          ok2_res = AsyncResult(self.task2["id"])
          nok_res = AsyncResult(self.task3["id"])
          self.assertEqual(repr(ok_res), "<AsyncResult: %s>" % (
                  self.task1["id"]))
          self.assertEqual(repr(ok2_res), "<AsyncResult: %s>" % (
                  self.task2["id"]))
          self.assertEqual(repr(nok_res), "<AsyncResult: %s>" % (
                  self.task3["id"]))

          pending_id = uuid()
          pending_res = AsyncResult(pending_id)
          self.assertEqual(repr(pending_res), "<AsyncResult: %s>" % (
                  pending_id))

      def test_hash(self):
          # Equal ids hash equally; different ids are expected to differ.
          self.assertEqual(hash(AsyncResult("x0w991")),
                           hash(AsyncResult("x0w991")))
          self.assertNotEqual(hash(AsyncResult("x0w991")),
                              hash(AsyncResult("x1w991")))

      def test_get_traceback(self):
          # save_result passes a traceback only for FAILURE and RETRY.
          ok_res = AsyncResult(self.task1["id"])
          nok_res = AsyncResult(self.task3["id"])
          nok_res2 = AsyncResult(self.task4["id"])
          self.assertFalse(ok_res.traceback)
          self.assertTrue(nok_res.traceback)
          self.assertTrue(nok_res2.traceback)

          pending_res = AsyncResult(uuid())
          self.assertFalse(pending_res.traceback)

      def test_get(self):
          ok_res = AsyncResult(self.task1["id"])
          ok2_res = AsyncResult(self.task2["id"])
          nok_res = AsyncResult(self.task3["id"])
          nok2_res = AsyncResult(self.task4["id"])

          self.assertEqual(ok_res.get(), "the")
          self.assertEqual(ok2_res.get(), "quick")
          # get() on the FAILURE fixture is expected to raise its error.
          self.assertRaises(KeyError, nok_res.get)
          self.assertIsInstance(nok2_res.result, KeyError)
          self.assertEqual(ok_res.info, "the")

      def test_get_timeout(self):
          res = AsyncResult(self.task4["id"])  # has RETRY status
          self.assertRaises(TimeoutError, res.get, timeout=0.1)

          pending_res = AsyncResult(uuid())
          self.assertRaises(TimeoutError, pending_res.get, timeout=0.1)

      @skip_if_quick
      def test_get_timeout_longer(self):
          res = AsyncResult(self.task4["id"])  # has RETRY status
          self.assertRaises(TimeoutError, res.get, timeout=1)

      def test_ready(self):
          # SUCCESS and FAILURE results are both expected to be ready.
          finished = (AsyncResult(self.task1["id"]),
                      AsyncResult(self.task2["id"]),
                      AsyncResult(self.task3["id"]))
          self.assertTrue(all(result.ready() for result in finished))
          # RETRY and never-stored ids are expected not to be ready.
          self.assertFalse(AsyncResult(self.task4["id"]).ready())
          self.assertFalse(AsyncResult(uuid()).ready())
  class test_ResultSet(unittest.TestCase):
      """Tests for ResultSet add/discard/update/clear."""

      def test_add_discard(self):
          result_set = ResultSet([])
          result_set.add(AsyncResult("1"))
          self.assertIn(AsyncResult("1"), result_set.results)

          # Discard the same result twice, then once more by its bare id;
          # none of these calls is expected to raise.
          for _ in range(2):
              result_set.discard(AsyncResult("1"))
          result_set.discard("1")
          self.assertNotIn(AsyncResult("1"), result_set.results)

          result_set.update([AsyncResult("2")])

      def test_clear(self):
          result_set = ResultSet([])
          before = result_set.results
          result_set.clear()
          # clear() must keep the very same ``results`` object.
          self.assertIs(result_set.results, before)
  class MockAsyncResultFailure(AsyncResult):
      """AsyncResult stub whose status is FAILURE and result a KeyError."""

      @property
      def status(self):
          return states.FAILURE

      @property
      def result(self):
          return KeyError("baz")

      def get(self, propagate=True, **kwargs):
          """Raise the result, or return it when *propagate* is false."""
          error = self.result
          if not propagate:
              return error
          raise error
  class MockAsyncResultSuccess(AsyncResult):
      """AsyncResult stub with SUCCESS status and a result of 42."""

      # Set to True on the instance once forget() has been called.
      forgotten = False

      @property
      def status(self):
          return states.SUCCESS

      @property
      def result(self):
          return 42

      def get(self, **kwargs):
          """Return the result; keyword arguments are ignored."""
          outcome = self.result
          return outcome

      def forget(self):
          """Record that forget() was called instead of forgetting."""
          self.forgotten = True
  class SimpleBackend(object):
      """Minimal backend stub that only implements ``get_many``."""

      # Class-level fallback; __init__ always sets a per-instance ``ids``.
      ids = []

      def __init__(self, ids=None):
          """Store *ids*, the task ids that ``get_many`` will report.

          When *ids* is omitted each instance gets its own new list.  The
          previous ``ids=[]`` default was a single list object shared by
          every instance created without arguments.
          """
          self.ids = [] if ids is None else ids

      def get_many(self, *args, **kwargs):
          """Return a generator of ``(task_id, {"result": index})`` pairs.

          One pair is produced per entry of ``self.ids``, where ``index``
          is the entry's position.  All arguments are accepted and ignored.
          """
          return ((task_id, {"result": index})
                  for index, task_id in enumerate(self.ids))
  class TestTaskSetResult(unittest.TestCase):
      """TaskSetResult tests; ``self.ts`` wraps ``self.size`` stored tasks."""

      def setUp(self):
          self.size = 10
          taskset_id = uuid()
          members = make_mock_taskset(self.size)
          self.ts = TaskSetResult(taskset_id, members)

      def test_total(self):
          self.assertEqual(self.ts.total, self.size)

      def test_iterate_raises(self):
          failing = MockAsyncResultFailure(uuid())
          results = iter(TaskSetResult(uuid(), [failing]))
          # Iterating over a failed member is expected to raise its error.
          self.assertRaises(KeyError, results.next)

      def test_forget(self):
          members = [MockAsyncResultSuccess(uuid()),
                     MockAsyncResultSuccess(uuid())]
          TaskSetResult(uuid(), members).forget()
          # forget() on the set must reach every member.
          for member in members:
              self.assertTrue(member.forgotten)

      def test_getitem(self):
          members = [MockAsyncResultSuccess(uuid()),
                     MockAsyncResultSuccess(uuid())]
          taskset = TaskSetResult(uuid(), members)
          self.assertIs(taskset[0], members[0])

      def test_save_restore(self):
          members = [MockAsyncResultSuccess(uuid()),
                     MockAsyncResultSuccess(uuid())]
          taskset = TaskSetResult(uuid(), members)

          taskset.save()
          # A bare object() is not a usable backend.
          self.assertRaises(AttributeError, taskset.save, backend=object())

          restored = TaskSetResult.restore(taskset.taskset_id)
          self.assertEqual(restored.subtasks, taskset.subtasks)

          taskset.delete()
          self.assertIsNone(TaskSetResult.restore(taskset.taskset_id))
          self.assertRaises(AttributeError,
                            TaskSetResult.restore, taskset.taskset_id,
                            backend=object())

      def test_join_native(self):
          stub = SimpleBackend()
          members = [AsyncResult(uuid(), backend=stub) for _ in range(10)]
          taskset = TaskSetResult(uuid(), members)
          # The stub reports each id's position as that task's result.
          stub.ids = [member.task_id for member in members]
          self.assertEqual(taskset.join_native(), range(10))

      def test_iter_native(self):
          stub = SimpleBackend()
          members = [AsyncResult(uuid(), backend=stub) for _ in range(10)]
          taskset = TaskSetResult(uuid(), members)
          stub.ids = [member.task_id for member in members]
          produced = list(taskset.iter_native())
          self.assertEqual(len(produced), 10)

      def test_iterate_yields(self):
          first = MockAsyncResultSuccess(uuid())
          second = MockAsyncResultSuccess(uuid())
          results = iter(TaskSetResult(uuid(), [first, second]))
          for _ in range(2):
              self.assertEqual(results.next(), 42)

      def test_iterate_eager(self):
          first = EagerResult(uuid(), 42, states.SUCCESS)
          second = EagerResult(uuid(), 42, states.SUCCESS)
          results = iter(TaskSetResult(uuid(), [first, second]))
          for _ in range(2):
              self.assertEqual(results.next(), 42)

      def test_join_timeout(self):
          done_a = MockAsyncResultSuccess(uuid())
          done_b = MockAsyncResultSuccess(uuid())
          # A plain AsyncResult with an id that was never stored.
          never_stored = AsyncResult(uuid())
          taskset = TaskSetResult(uuid(), [done_a, done_b, never_stored])
          self.assertRaises(TimeoutError, taskset.join, timeout=0.0000001)

      def test_itersubtasks(self):
          # make_mock_taskset stores position ``i`` as task i's result.
          for position, subtask in enumerate(self.ts.itersubtasks()):
              self.assertEqual(subtask.get(), position)

      def test___iter__(self):
          # Sorted before comparing, so arrival order does not matter.
          collected = sorted(iter(self.ts))
          self.assertListEqual(collected, list(xrange(self.size)))

      def test_join(self):
          self.assertListEqual(self.ts.join(), list(xrange(self.size)))

      def test_successful(self):
          self.assertTrue(self.ts.successful())

      def test_failed(self):
          self.assertFalse(self.ts.failed())

      def test_waiting(self):
          self.assertFalse(self.ts.waiting())

      def test_ready(self):
          self.assertTrue(self.ts.ready())

      def test_completed_count(self):
          self.assertEqual(self.ts.completed_count(), self.ts.total)
  class TestPendingAsyncResult(unittest.TestCase):
      """Checks on an AsyncResult whose id was never stored."""

      def setUp(self):
          fresh_id = uuid()
          self.task = AsyncResult(fresh_id)

      def test_result(self):
          pending = self.task
          self.assertIsNone(pending.result)
  class TestFailedTaskSetResult(TestTaskSetResult):
      """Re-run the TaskSetResult tests with one failed task appended.

      ``self.ts`` holds ten successful results followed by one FAILURE
      result; the overridden tests expect that failure to surface.
      """

      def setUp(self):
          self.size = 11
          successes = make_mock_taskset(10)
          failure = mock_task("ts11", states.FAILURE, KeyError("Baz"))
          save_result(failure)
          members = successes + [AsyncResult(failure["id"])]
          self.ts = TaskSetResult(uuid(), members)

      def test_itersubtasks(self):
          subtasks = self.ts.itersubtasks()
          # Every subtask but the last is a success whose result is its
          # position.
          for position in xrange(self.size - 1):
              self.assertEqual(subtasks.next().get(), position)
          # The final subtask is the failed one.
          self.assertRaises(KeyError, subtasks.next().get)

      def test_completed_count(self):
          self.assertEqual(self.ts.completed_count(), self.ts.total - 1)

      def test___iter__(self):
          # Draining the iterator is expected to hit the failed task.
          self.assertRaises(KeyError, list, iter(self.ts))

      def test_join(self):
          self.assertRaises(KeyError, self.ts.join)

      def test_successful(self):
          self.assertFalse(self.ts.successful())

      def test_failed(self):
          self.assertTrue(self.ts.failed())
  class TestTaskSetPending(unittest.TestCase):
      """TaskSetResult tests where no subtask id was ever stored."""

      def setUp(self):
          never_stored = [AsyncResult(uuid()), AsyncResult(uuid())]
          self.ts = TaskSetResult(uuid(), never_stored)

      def test_completed_count(self):
          self.assertEqual(self.ts.completed_count(), 0)

      def test_ready(self):
          self.assertFalse(self.ts.ready())

      def test_waiting(self):
          self.assertTrue(self.ts.waiting())

      # The two methods below do not start with ``test``, so the default
      # unittest loader will not collect them.

      def x_join(self):
          join = self.ts.join
          self.assertRaises(TimeoutError, join, timeout=0.001)

      @skip_if_quick
      def x_join_longer(self):
          join = self.ts.join
          self.assertRaises(TimeoutError, join, timeout=1)
  class RaisingTask(Task):
      """Task whose ``run`` always raises ``KeyError("xy")``."""

      def run(self, x, y):
          # The arguments are accepted but never used.
          failure = KeyError("xy")
          raise failure
  class TestEagerResult(unittest.TestCase):
      """Tests for results produced without going through a worker."""

      def test_wait_raises(self):
          outcome = RaisingTask.apply(args=[3, 3])
          # wait() is expected to re-raise the task's KeyError.
          self.assertRaises(KeyError, outcome.wait)

      def test_wait(self):
          eager = EagerResult("x", "x", states.RETRY)
          eager.wait()
          # Waiting must leave both spellings of the state at RETRY.
          self.assertEqual(eager.state, states.RETRY)
          self.assertEqual(eager.status, states.RETRY)

      def test_revoke(self):
          outcome = RaisingTask.apply(args=[3, 3])
          revoked = outcome.revoke()
          self.assertFalse(revoked)