test_result.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. from __future__ import absolute_import
  2. from __future__ import with_statement
  3. from pickle import loads, dumps
  4. from mock import Mock
  5. from celery import states
  6. from celery.app import app_or_default
  7. from celery.exceptions import IncompleteStream
  8. from celery.utils import uuid
  9. from celery.utils.serialization import pickle
  10. from celery.result import (
  11. AsyncResult,
  12. EagerResult,
  13. GroupResult,
  14. TaskSetResult,
  15. ResultSet,
  16. from_serializable,
  17. )
  18. from celery.exceptions import TimeoutError
  19. from celery.task import task
  20. from celery.task.base import Task
  21. from celery.tests.utils import AppCase
  22. from celery.tests.utils import skip_if_quick
  23. @task()
  24. def mytask():
  25. pass
  26. def mock_task(name, state, result):
  27. return dict(id=uuid(), name=name, state=state, result=result)
  28. def save_result(task):
  29. app = app_or_default()
  30. traceback = 'Some traceback'
  31. if task['state'] == states.SUCCESS:
  32. app.backend.mark_as_done(task['id'], task['result'])
  33. elif task['state'] == states.RETRY:
  34. app.backend.mark_as_retry(
  35. task['id'], task['result'], traceback=traceback,
  36. )
  37. else:
  38. app.backend.mark_as_failure(
  39. task['id'], task['result'], traceback=traceback,
  40. )
  41. def make_mock_group(size=10):
  42. tasks = [mock_task('ts%d' % i, states.SUCCESS, i) for i in xrange(size)]
  43. [save_result(task) for task in tasks]
  44. return [AsyncResult(task['id']) for task in tasks]
  45. class test_AsyncResult(AppCase):
  46. def setup(self):
  47. self.task1 = mock_task('task1', states.SUCCESS, 'the')
  48. self.task2 = mock_task('task2', states.SUCCESS, 'quick')
  49. self.task3 = mock_task('task3', states.FAILURE, KeyError('brown'))
  50. self.task4 = mock_task('task3', states.RETRY, KeyError('red'))
  51. for task in (self.task1, self.task2, self.task3, self.task4):
  52. save_result(task)
  53. def test_compat_properties(self):
  54. x = AsyncResult('1')
  55. self.assertEqual(x.task_id, x.id)
  56. x.task_id = '2'
  57. self.assertEqual(x.id, '2')
  58. def test_children(self):
  59. x = AsyncResult('1')
  60. children = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
  61. x.backend = Mock()
  62. x.backend.get_children.return_value = children
  63. x.backend.READY_STATES = states.READY_STATES
  64. self.assertTrue(x.children)
  65. self.assertEqual(len(x.children), 3)
  66. def test_get_children(self):
  67. tid = uuid()
  68. x = AsyncResult(tid)
  69. child = [AsyncResult(uuid()).serializable() for i in xrange(10)]
  70. x.backend._cache[tid] = {'children': child}
  71. self.assertTrue(x.children)
  72. self.assertEqual(len(x.children), 10)
  73. x.backend._cache[tid] = {'result': None}
  74. self.assertIsNone(x.children)
  75. def test_build_graph_get_leaf_collect(self):
  76. x = AsyncResult('1')
  77. x.backend._cache['1'] = {'status': states.SUCCESS, 'result': None}
  78. c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
  79. x.iterdeps = Mock()
  80. x.iterdeps.return_value = (
  81. (None, x),
  82. (x, c[0]),
  83. (c[0], c[1]),
  84. (c[1], c[2])
  85. )
  86. x.backend.READY_STATES = states.READY_STATES
  87. self.assertTrue(x.graph)
  88. self.assertIs(x.get_leaf(), 2)
  89. it = x.collect()
  90. self.assertListEqual(list(it), [
  91. (x, None),
  92. (c[0], 0),
  93. (c[1], 1),
  94. (c[2], 2),
  95. ])
  96. def test_iterdeps(self):
  97. x = AsyncResult('1')
  98. x.backend._cache['1'] = {'status': states.SUCCESS, 'result': None}
  99. c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
  100. for child in c:
  101. child.backend = Mock()
  102. child.backend.get_children.return_value = []
  103. x.backend.get_children = Mock()
  104. x.backend.get_children.return_value = c
  105. it = x.iterdeps()
  106. self.assertListEqual(list(it), [
  107. (None, x),
  108. (x, c[0]),
  109. (x, c[1]),
  110. (x, c[2]),
  111. ])
  112. x.backend._cache.pop('1')
  113. x.ready = Mock()
  114. x.ready.return_value = False
  115. with self.assertRaises(IncompleteStream):
  116. list(x.iterdeps())
  117. list(x.iterdeps(intermediate=True))
  118. def test_eq_not_implemented(self):
  119. self.assertFalse(AsyncResult('1') == object())
  120. def test_reduce(self):
  121. a1 = AsyncResult('uuid', task_name=mytask.name)
  122. restored = pickle.loads(pickle.dumps(a1))
  123. self.assertEqual(restored.id, 'uuid')
  124. self.assertEqual(restored.task_name, mytask.name)
  125. a2 = AsyncResult('uuid')
  126. self.assertEqual(pickle.loads(pickle.dumps(a2)).id, 'uuid')
  127. def test_successful(self):
  128. ok_res = AsyncResult(self.task1['id'])
  129. nok_res = AsyncResult(self.task3['id'])
  130. nok_res2 = AsyncResult(self.task4['id'])
  131. self.assertTrue(ok_res.successful())
  132. self.assertFalse(nok_res.successful())
  133. self.assertFalse(nok_res2.successful())
  134. pending_res = AsyncResult(uuid())
  135. self.assertFalse(pending_res.successful())
  136. def test_str(self):
  137. ok_res = AsyncResult(self.task1['id'])
  138. ok2_res = AsyncResult(self.task2['id'])
  139. nok_res = AsyncResult(self.task3['id'])
  140. self.assertEqual(str(ok_res), self.task1['id'])
  141. self.assertEqual(str(ok2_res), self.task2['id'])
  142. self.assertEqual(str(nok_res), self.task3['id'])
  143. pending_id = uuid()
  144. pending_res = AsyncResult(pending_id)
  145. self.assertEqual(str(pending_res), pending_id)
  146. def test_repr(self):
  147. ok_res = AsyncResult(self.task1['id'])
  148. ok2_res = AsyncResult(self.task2['id'])
  149. nok_res = AsyncResult(self.task3['id'])
  150. self.assertEqual(repr(ok_res), '<AsyncResult: %s>' % (
  151. self.task1['id']))
  152. self.assertEqual(repr(ok2_res), '<AsyncResult: %s>' % (
  153. self.task2['id']))
  154. self.assertEqual(repr(nok_res), '<AsyncResult: %s>' % (
  155. self.task3['id']))
  156. pending_id = uuid()
  157. pending_res = AsyncResult(pending_id)
  158. self.assertEqual(repr(pending_res), '<AsyncResult: %s>' % (
  159. pending_id))
  160. def test_hash(self):
  161. self.assertEqual(hash(AsyncResult('x0w991')),
  162. hash(AsyncResult('x0w991')))
  163. self.assertNotEqual(hash(AsyncResult('x0w991')),
  164. hash(AsyncResult('x1w991')))
  165. def test_get_traceback(self):
  166. ok_res = AsyncResult(self.task1['id'])
  167. nok_res = AsyncResult(self.task3['id'])
  168. nok_res2 = AsyncResult(self.task4['id'])
  169. self.assertFalse(ok_res.traceback)
  170. self.assertTrue(nok_res.traceback)
  171. self.assertTrue(nok_res2.traceback)
  172. pending_res = AsyncResult(uuid())
  173. self.assertFalse(pending_res.traceback)
  174. def test_get(self):
  175. ok_res = AsyncResult(self.task1['id'])
  176. ok2_res = AsyncResult(self.task2['id'])
  177. nok_res = AsyncResult(self.task3['id'])
  178. nok2_res = AsyncResult(self.task4['id'])
  179. self.assertEqual(ok_res.get(), 'the')
  180. self.assertEqual(ok2_res.get(), 'quick')
  181. with self.assertRaises(KeyError):
  182. nok_res.get()
  183. self.assertTrue(nok_res.get(propagate=False))
  184. self.assertIsInstance(nok2_res.result, KeyError)
  185. self.assertEqual(ok_res.info, 'the')
  186. def test_get_timeout(self):
  187. res = AsyncResult(self.task4['id']) # has RETRY state
  188. with self.assertRaises(TimeoutError):
  189. res.get(timeout=0.1)
  190. pending_res = AsyncResult(uuid())
  191. with self.assertRaises(TimeoutError):
  192. pending_res.get(timeout=0.1)
  193. @skip_if_quick
  194. def test_get_timeout_longer(self):
  195. res = AsyncResult(self.task4['id']) # has RETRY state
  196. with self.assertRaises(TimeoutError):
  197. res.get(timeout=1)
  198. def test_ready(self):
  199. oks = (AsyncResult(self.task1['id']),
  200. AsyncResult(self.task2['id']),
  201. AsyncResult(self.task3['id']))
  202. self.assertTrue(all(result.ready() for result in oks))
  203. self.assertFalse(AsyncResult(self.task4['id']).ready())
  204. self.assertFalse(AsyncResult(uuid()).ready())
  205. class test_ResultSet(AppCase):
  206. def test_resultset_repr(self):
  207. self.assertTrue(repr(ResultSet(map(AsyncResult, ['1', '2', '3']))))
  208. def test_eq_other(self):
  209. self.assertFalse(ResultSet([1, 3, 3]) == 1)
  210. self.assertTrue(ResultSet([1]) == ResultSet([1]))
  211. def test_get(self):
  212. x = ResultSet(map(AsyncResult, [1, 2, 3]))
  213. b = x.results[0].backend = Mock()
  214. b.supports_native_join = False
  215. x.join_native = Mock()
  216. x.join = Mock()
  217. x.get()
  218. self.assertTrue(x.join.called)
  219. b.supports_native_join = True
  220. x.get()
  221. self.assertTrue(x.join_native.called)
  222. def test_add(self):
  223. x = ResultSet([1])
  224. x.add(2)
  225. self.assertEqual(len(x), 2)
  226. x.add(2)
  227. self.assertEqual(len(x), 2)
  228. def test_add_discard(self):
  229. x = ResultSet([])
  230. x.add(AsyncResult('1'))
  231. self.assertIn(AsyncResult('1'), x.results)
  232. x.discard(AsyncResult('1'))
  233. x.discard(AsyncResult('1'))
  234. x.discard('1')
  235. self.assertNotIn(AsyncResult('1'), x.results)
  236. x.update([AsyncResult('2')])
  237. def test_clear(self):
  238. x = ResultSet([])
  239. r = x.results
  240. x.clear()
  241. self.assertIs(x.results, r)
  242. class MockAsyncResultFailure(AsyncResult):
  243. @property
  244. def result(self):
  245. return KeyError('baz')
  246. @property
  247. def state(self):
  248. return states.FAILURE
  249. def get(self, propagate=True, **kwargs):
  250. if propagate:
  251. raise self.result
  252. return self.result
  253. class MockAsyncResultSuccess(AsyncResult):
  254. forgotten = False
  255. def forget(self):
  256. self.forgotten = True
  257. @property
  258. def result(self):
  259. return 42
  260. @property
  261. def state(self):
  262. return states.SUCCESS
  263. def get(self, **kwargs):
  264. return self.result
  265. class SimpleBackend(object):
  266. ids = []
  267. def __init__(self, ids=[]):
  268. self.ids = ids
  269. def get_many(self, *args, **kwargs):
  270. return ((id, {'result': i, 'status': states.SUCCESS})
  271. for i, id in enumerate(self.ids))
  272. class test_TaskSetResult(AppCase):
  273. def setup(self):
  274. self.size = 10
  275. self.ts = TaskSetResult(uuid(), make_mock_group(self.size))
  276. def test_total(self):
  277. self.assertEqual(self.ts.total, self.size)
  278. def test_compat_properties(self):
  279. self.assertEqual(self.ts.taskset_id, self.ts.id)
  280. self.ts.taskset_id = 'foo'
  281. self.assertEqual(self.ts.taskset_id, 'foo')
  282. def test_compat_subtasks_kwarg(self):
  283. x = TaskSetResult(uuid(), subtasks=[1, 2, 3])
  284. self.assertEqual(x.results, [1, 2, 3])
  285. def test_itersubtasks(self):
  286. it = self.ts.itersubtasks()
  287. for i, t in enumerate(it):
  288. self.assertEqual(t.get(), i)
  289. class test_GroupResult(AppCase):
  290. def setup(self):
  291. self.size = 10
  292. self.ts = GroupResult(uuid(), make_mock_group(self.size))
  293. def test_len(self):
  294. self.assertEqual(len(self.ts), self.size)
  295. def test_eq_other(self):
  296. self.assertFalse(self.ts == 1)
  297. def test_reduce(self):
  298. self.assertTrue(loads(dumps(self.ts)))
  299. def test_iterate_raises(self):
  300. ar = MockAsyncResultFailure(uuid())
  301. ts = GroupResult(uuid(), [ar])
  302. it = iter(ts)
  303. with self.assertRaises(KeyError):
  304. it.next()
  305. def test_forget(self):
  306. subs = [MockAsyncResultSuccess(uuid()),
  307. MockAsyncResultSuccess(uuid())]
  308. ts = GroupResult(uuid(), subs)
  309. ts.forget()
  310. for sub in subs:
  311. self.assertTrue(sub.forgotten)
  312. def test_getitem(self):
  313. subs = [MockAsyncResultSuccess(uuid()),
  314. MockAsyncResultSuccess(uuid())]
  315. ts = GroupResult(uuid(), subs)
  316. self.assertIs(ts[0], subs[0])
  317. def test_save_restore(self):
  318. subs = [MockAsyncResultSuccess(uuid()),
  319. MockAsyncResultSuccess(uuid())]
  320. ts = GroupResult(uuid(), subs)
  321. ts.save()
  322. with self.assertRaises(AttributeError):
  323. ts.save(backend=object())
  324. self.assertEqual(GroupResult.restore(ts.id).subtasks,
  325. ts.subtasks)
  326. ts.delete()
  327. self.assertIsNone(GroupResult.restore(ts.id))
  328. with self.assertRaises(AttributeError):
  329. GroupResult.restore(ts.id, backend=object())
  330. def test_join_native(self):
  331. backend = SimpleBackend()
  332. subtasks = [AsyncResult(uuid(), backend=backend)
  333. for i in range(10)]
  334. ts = GroupResult(uuid(), subtasks)
  335. backend.ids = [subtask.id for subtask in subtasks]
  336. res = ts.join_native()
  337. self.assertEqual(res, range(10))
  338. def test_iter_native(self):
  339. backend = SimpleBackend()
  340. subtasks = [AsyncResult(uuid(), backend=backend)
  341. for i in range(10)]
  342. ts = GroupResult(uuid(), subtasks)
  343. backend.ids = [subtask.id for subtask in subtasks]
  344. self.assertEqual(len(list(ts.iter_native())), 10)
  345. def test_iterate_yields(self):
  346. ar = MockAsyncResultSuccess(uuid())
  347. ar2 = MockAsyncResultSuccess(uuid())
  348. ts = GroupResult(uuid(), [ar, ar2])
  349. it = iter(ts)
  350. self.assertEqual(it.next(), 42)
  351. self.assertEqual(it.next(), 42)
  352. def test_iterate_eager(self):
  353. ar1 = EagerResult(uuid(), 42, states.SUCCESS)
  354. ar2 = EagerResult(uuid(), 42, states.SUCCESS)
  355. ts = GroupResult(uuid(), [ar1, ar2])
  356. it = iter(ts)
  357. self.assertEqual(it.next(), 42)
  358. self.assertEqual(it.next(), 42)
  359. def test_join_timeout(self):
  360. ar = MockAsyncResultSuccess(uuid())
  361. ar2 = MockAsyncResultSuccess(uuid())
  362. ar3 = AsyncResult(uuid())
  363. ts = GroupResult(uuid(), [ar, ar2, ar3])
  364. with self.assertRaises(TimeoutError):
  365. ts.join(timeout=0.0000001)
  366. def test___iter__(self):
  367. it = iter(self.ts)
  368. results = sorted(list(it))
  369. self.assertListEqual(results, list(xrange(self.size)))
  370. def test_join(self):
  371. joined = self.ts.join()
  372. self.assertListEqual(joined, list(xrange(self.size)))
  373. def test_successful(self):
  374. self.assertTrue(self.ts.successful())
  375. def test_failed(self):
  376. self.assertFalse(self.ts.failed())
  377. def test_waiting(self):
  378. self.assertFalse(self.ts.waiting())
  379. def test_ready(self):
  380. self.assertTrue(self.ts.ready())
  381. def test_completed_count(self):
  382. self.assertEqual(self.ts.completed_count(), len(self.ts))
  383. class test_pending_AsyncResult(AppCase):
  384. def setup(self):
  385. self.task = AsyncResult(uuid())
  386. def test_result(self):
  387. self.assertIsNone(self.task.result)
  388. class test_failed_AsyncResult(test_GroupResult):
  389. def setup(self):
  390. self.size = 11
  391. subtasks = make_mock_group(10)
  392. failed = mock_task('ts11', states.FAILURE, KeyError('Baz'))
  393. save_result(failed)
  394. failed_res = AsyncResult(failed['id'])
  395. self.ts = GroupResult(uuid(), subtasks + [failed_res])
  396. def test_completed_count(self):
  397. self.assertEqual(self.ts.completed_count(), len(self.ts) - 1)
  398. def test___iter__(self):
  399. it = iter(self.ts)
  400. def consume():
  401. return list(it)
  402. with self.assertRaises(KeyError):
  403. consume()
  404. def test_join(self):
  405. with self.assertRaises(KeyError):
  406. self.ts.join()
  407. def test_successful(self):
  408. self.assertFalse(self.ts.successful())
  409. def test_failed(self):
  410. self.assertTrue(self.ts.failed())
  411. class test_pending_Group(AppCase):
  412. def setup(self):
  413. self.ts = GroupResult(uuid(), [AsyncResult(uuid()),
  414. AsyncResult(uuid())])
  415. def test_completed_count(self):
  416. self.assertEqual(self.ts.completed_count(), 0)
  417. def test_ready(self):
  418. self.assertFalse(self.ts.ready())
  419. def test_waiting(self):
  420. self.assertTrue(self.ts.waiting())
  421. def x_join(self):
  422. with self.assertRaises(TimeoutError):
  423. self.ts.join(timeout=0.001)
  424. @skip_if_quick
  425. def x_join_longer(self):
  426. with self.assertRaises(TimeoutError):
  427. self.ts.join(timeout=1)
  428. class RaisingTask(Task):
  429. def run(self, x, y):
  430. raise KeyError('xy')
  431. class test_EagerResult(AppCase):
  432. def test_wait_raises(self):
  433. res = RaisingTask.apply(args=[3, 3])
  434. with self.assertRaises(KeyError):
  435. res.wait()
  436. self.assertTrue(res.wait(propagate=False))
  437. def test_wait(self):
  438. res = EagerResult('x', 'x', states.RETRY)
  439. res.wait()
  440. self.assertEqual(res.state, states.RETRY)
  441. self.assertEqual(res.status, states.RETRY)
  442. def test_forget(self):
  443. res = EagerResult('x', 'x', states.RETRY)
  444. res.forget()
  445. def test_revoke(self):
  446. res = RaisingTask.apply(args=[3, 3])
  447. self.assertFalse(res.revoke())
  448. class test_serializable(AppCase):
  449. def test_AsyncResult(self):
  450. x = AsyncResult(uuid())
  451. self.assertEqual(x, from_serializable(x.serializable()))
  452. self.assertEqual(x, from_serializable(x))
  453. def test_GroupResult(self):
  454. x = GroupResult(uuid(), [AsyncResult(uuid()) for _ in range(10)])
  455. self.assertEqual(x, from_serializable(x.serializable()))
  456. self.assertEqual(x, from_serializable(x))