# test_canvas.py: integration tests for Celery canvas primitives
# (chain, group, chord, ResultSet).
from __future__ import absolute_import, unicode_literals

from datetime import datetime, timedelta

import pytest

from celery import chain, chord, group
from celery.exceptions import TimeoutError
from celery.result import AsyncResult, GroupResult, ResultSet

from .conftest import flaky, get_active_redis_channels, get_redis_connection
from .tasks import (add, add_chord_to_chord, add_replaced, add_to_all,
                    add_to_all_to_chord, build_chain_inside_task, chord_error,
                    collect_ids, delayed_sum, delayed_sum_with_soft_guard,
                    fail, identity, ids, print_unicode, raise_error,
                    redis_echo, second_order_replace1, tsum)
  13. TIMEOUT = 120
  14. class test_chain:
  15. @flaky
  16. def test_simple_chain(self, manager):
  17. c = add.s(4, 4) | add.s(8) | add.s(16)
  18. assert c().get(timeout=TIMEOUT) == 32
  19. @flaky
  20. def test_single_chain(self, manager):
  21. c = chain(add.s(3, 4))()
  22. assert c.get(timeout=TIMEOUT) == 7
  23. @flaky
  24. def test_complex_chain(self, manager):
  25. c = (
  26. add.s(2, 2) | (
  27. add.s(4) | add_replaced.s(8) | add.s(16) | add.s(32)
  28. ) |
  29. group(add.s(i) for i in range(4))
  30. )
  31. res = c()
  32. assert res.get(timeout=TIMEOUT) == [64, 65, 66, 67]
  33. @flaky
  34. def test_group_results_in_chain(self, manager):
  35. # This adds in an explicit test for the special case added in commit
  36. # 1e3fcaa969de6ad32b52a3ed8e74281e5e5360e6
  37. c = (
  38. group(
  39. add.s(1, 2) | group(
  40. add.s(1), add.s(2)
  41. )
  42. )
  43. )
  44. res = c()
  45. assert res.get(timeout=TIMEOUT) == [4, 5]
  46. @flaky
  47. def test_chain_inside_group_receives_arguments(self, manager):
  48. c = (
  49. add.s(5, 6) |
  50. group((add.s(1) | add.s(2), add.s(3)))
  51. )
  52. res = c()
  53. assert res.get(timeout=TIMEOUT) == [14, 14]
  54. @flaky
  55. def test_eager_chain_inside_task(self, manager):
  56. from .tasks import chain_add
  57. prev = chain_add.app.conf.task_always_eager
  58. chain_add.app.conf.task_always_eager = True
  59. chain_add.apply_async(args=(4, 8), throw=True).get()
  60. chain_add.app.conf.task_always_eager = prev
  61. @flaky
  62. def test_group_chord_group_chain(self, manager):
  63. from celery.five import bytes_if_py2
  64. if not manager.app.conf.result_backend.startswith('redis'):
  65. raise pytest.skip('Requires redis result backend.')
  66. redis_connection = get_redis_connection()
  67. redis_connection.delete('redis-echo')
  68. before = group(redis_echo.si('before {}'.format(i)) for i in range(3))
  69. connect = redis_echo.si('connect')
  70. after = group(redis_echo.si('after {}'.format(i)) for i in range(2))
  71. result = (before | connect | after).delay()
  72. result.get(timeout=TIMEOUT)
  73. redis_messages = list(map(
  74. bytes_if_py2,
  75. redis_connection.lrange('redis-echo', 0, -1)
  76. ))
  77. before_items = \
  78. set(map(bytes_if_py2, (b'before 0', b'before 1', b'before 2')))
  79. after_items = set(map(bytes_if_py2, (b'after 0', b'after 1')))
  80. assert set(redis_messages[:3]) == before_items
  81. assert redis_messages[3] == b'connect'
  82. assert set(redis_messages[4:]) == after_items
  83. redis_connection.delete('redis-echo')
  84. @flaky
  85. def test_second_order_replace(self, manager):
  86. from celery.five import bytes_if_py2
  87. if not manager.app.conf.result_backend.startswith('redis'):
  88. raise pytest.skip('Requires redis result backend.')
  89. redis_connection = get_redis_connection()
  90. redis_connection.delete('redis-echo')
  91. result = second_order_replace1.delay()
  92. result.get(timeout=TIMEOUT)
  93. redis_messages = list(map(
  94. bytes_if_py2,
  95. redis_connection.lrange('redis-echo', 0, -1)
  96. ))
  97. expected_messages = [b'In A', b'In B', b'In/Out C', b'Out B', b'Out A']
  98. assert redis_messages == expected_messages
  99. @flaky
  100. def test_parent_ids(self, manager, num=10):
  101. assert manager.inspect().ping()
  102. c = chain(ids.si(i=i) for i in range(num))
  103. c.freeze()
  104. res = c()
  105. try:
  106. res.get(timeout=TIMEOUT)
  107. except TimeoutError:
  108. print(manager.inspect.active())
  109. print(manager.inspect.reserved())
  110. print(manager.inspect.stats())
  111. raise
  112. self.assert_ids(res, num - 1)
  113. def assert_ids(self, res, size):
  114. i, root = size, res
  115. while root.parent:
  116. root = root.parent
  117. node = res
  118. while node:
  119. root_id, parent_id, value = node.get(timeout=30)
  120. assert value == i
  121. if node.parent:
  122. assert parent_id == node.parent.id
  123. assert root_id == root.id
  124. node = node.parent
  125. i -= 1
  126. def test_chord_soft_timeout_recuperation(self, manager):
  127. """Test that if soft timeout happens in task but is managed by task,
  128. chord still get results normally
  129. """
  130. if not manager.app.conf.result_backend.startswith('redis'):
  131. raise pytest.skip('Requires redis result backend.')
  132. c = chord([
  133. # return 3
  134. add.s(1, 2),
  135. # return 0 after managing soft timeout
  136. delayed_sum_with_soft_guard.s(
  137. [100], pause_time=2
  138. ).set(
  139. soft_time_limit=1
  140. ),
  141. ])
  142. result = c(delayed_sum.s(pause_time=0)).get()
  143. assert result == 3
  144. def test_chain_error_handler_with_eta(self, manager):
  145. try:
  146. manager.app.backend.ensure_chords_allowed()
  147. except NotImplementedError as e:
  148. raise pytest.skip(e.args[0])
  149. eta = datetime.utcnow() + timedelta(seconds=10)
  150. c = chain(
  151. group(
  152. add.s(1, 2),
  153. add.s(3, 4),
  154. ),
  155. tsum.s()
  156. ).on_error(print_unicode.s()).apply_async(eta=eta)
  157. result = c.get()
  158. assert result == 10
  159. @flaky
  160. def test_groupresult_serialization(self, manager):
  161. """Test GroupResult is correctly serialized
  162. to save in the result backend"""
  163. try:
  164. manager.app.backend.ensure_chords_allowed()
  165. except NotImplementedError as e:
  166. raise pytest.skip(e.args[0])
  167. async_result = build_chain_inside_task.delay()
  168. result = async_result.get()
  169. assert len(result) == 2
  170. assert isinstance(result[0][1], list)
  171. class test_result_set:
  172. @flaky
  173. def test_result_set(self, manager):
  174. assert manager.inspect().ping()
  175. rs = ResultSet([add.delay(1, 1), add.delay(2, 2)])
  176. assert rs.get(timeout=TIMEOUT) == [2, 4]
  177. @flaky
  178. def test_result_set_error(self, manager):
  179. assert manager.inspect().ping()
  180. rs = ResultSet([raise_error.delay(), add.delay(1, 1)])
  181. rs.get(timeout=TIMEOUT, propagate=False)
  182. assert rs.results[0].failed()
  183. assert rs.results[1].successful()
  184. class test_group:
  185. @flaky
  186. def test_empty_group_result(self, manager):
  187. if not manager.app.conf.result_backend.startswith('redis'):
  188. raise pytest.skip('Requires redis result backend.')
  189. task = group([])
  190. result = task.apply_async()
  191. GroupResult.save(result)
  192. task = GroupResult.restore(result.id)
  193. assert task.results == []
  194. @flaky
  195. def test_parent_ids(self, manager):
  196. assert manager.inspect().ping()
  197. g = (
  198. ids.si(i=1) |
  199. ids.si(i=2) |
  200. group(ids.si(i=i) for i in range(2, 50))
  201. )
  202. res = g()
  203. expected_root_id = res.parent.parent.id
  204. expected_parent_id = res.parent.id
  205. values = res.get(timeout=TIMEOUT)
  206. for i, r in enumerate(values):
  207. root_id, parent_id, value = r
  208. assert root_id == expected_root_id
  209. assert parent_id == expected_parent_id
  210. assert value == i + 2
  211. @flaky
  212. def test_nested_group(self, manager):
  213. assert manager.inspect().ping()
  214. c = group(
  215. add.si(1, 10),
  216. group(
  217. add.si(1, 100),
  218. group(
  219. add.si(1, 1000),
  220. add.si(1, 2000),
  221. ),
  222. ),
  223. )
  224. res = c()
  225. assert res.get(timeout=TIMEOUT) == [11, 101, 1001, 2001]
  226. def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
  227. root_id, parent_id, value = r.get(timeout=TIMEOUT)
  228. assert expected_value == value
  229. assert root_id == expected_root_id
  230. assert parent_id == expected_parent_id
  231. class test_chord:
  232. @flaky
  233. def test_redis_subscribed_channels_leak(self, manager):
  234. if not manager.app.conf.result_backend.startswith('redis'):
  235. raise pytest.skip('Requires redis result backend.')
  236. manager.app.backend.result_consumer.on_after_fork()
  237. initial_channels = get_active_redis_channels()
  238. initial_channels_count = len(initial_channels)
  239. total_chords = 10
  240. async_results = [
  241. chord([add.s(5, 6), add.s(6, 7)])(delayed_sum.s())
  242. for _ in range(total_chords)
  243. ]
  244. manager.assert_result_tasks_in_progress_or_completed(async_results)
  245. channels_before = get_active_redis_channels()
  246. channels_before_count = len(channels_before)
  247. assert set(channels_before) != set(initial_channels)
  248. assert channels_before_count > initial_channels_count
  249. # The total number of active Redis channels at this point
  250. # is the number of chord header tasks multiplied by the
  251. # total chord tasks, plus the initial channels
  252. # (existing from previous tests).
  253. chord_header_task_count = 2
  254. assert channels_before_count <= \
  255. chord_header_task_count * total_chords + initial_channels_count
  256. result_values = [
  257. result.get(timeout=TIMEOUT)
  258. for result in async_results
  259. ]
  260. assert result_values == [24] * total_chords
  261. channels_after = get_active_redis_channels()
  262. channels_after_count = len(channels_after)
  263. assert channels_after_count == initial_channels_count
  264. assert set(channels_after) == set(initial_channels)
  265. @flaky
  266. def test_replaced_nested_chord(self, manager):
  267. try:
  268. manager.app.backend.ensure_chords_allowed()
  269. except NotImplementedError as e:
  270. raise pytest.skip(e.args[0])
  271. c1 = chord([
  272. chord(
  273. [add.s(1, 2), add_replaced.s(3, 4)],
  274. add_to_all.s(5),
  275. ) | tsum.s(),
  276. chord(
  277. [add_replaced.s(6, 7), add.s(0, 0)],
  278. add_to_all.s(8),
  279. ) | tsum.s(),
  280. ], add_to_all.s(9))
  281. res1 = c1()
  282. assert res1.get(timeout=TIMEOUT) == [29, 38]
  283. @flaky
  284. def test_add_to_chord(self, manager):
  285. if not manager.app.conf.result_backend.startswith('redis'):
  286. raise pytest.skip('Requires redis result backend.')
  287. c = group([add_to_all_to_chord.s([1, 2, 3], 4)]) | identity.s()
  288. res = c()
  289. assert res.get() == [0, 5, 6, 7]
  290. @flaky
  291. def test_add_chord_to_chord(self, manager):
  292. if not manager.app.conf.result_backend.startswith('redis'):
  293. raise pytest.skip('Requires redis result backend.')
  294. c = group([add_chord_to_chord.s([1, 2, 3], 4)]) | identity.s()
  295. res = c()
  296. assert res.get() == [0, 5 + 6 + 7]
  297. @flaky
  298. def test_eager_chord_inside_task(self, manager):
  299. from .tasks import chord_add
  300. prev = chord_add.app.conf.task_always_eager
  301. chord_add.app.conf.task_always_eager = True
  302. chord_add.apply_async(args=(4, 8), throw=True).get()
  303. chord_add.app.conf.task_always_eager = prev
  304. @flaky
  305. def test_group_chain(self, manager):
  306. if not manager.app.conf.result_backend.startswith('redis'):
  307. raise pytest.skip('Requires redis result backend.')
  308. c = (
  309. add.s(2, 2) |
  310. group(add.s(i) for i in range(4)) |
  311. add_to_all.s(8)
  312. )
  313. res = c()
  314. assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15]
  315. @flaky
  316. def test_nested_group_chain(self, manager):
  317. try:
  318. manager.app.backend.ensure_chords_allowed()
  319. except NotImplementedError as e:
  320. raise pytest.skip(e.args[0])
  321. if not manager.app.backend.supports_native_join:
  322. raise pytest.skip('Requires native join support.')
  323. c = chain(
  324. add.si(1, 0),
  325. group(
  326. add.si(1, 100),
  327. chain(
  328. add.si(1, 200),
  329. group(
  330. add.si(1, 1000),
  331. add.si(1, 2000),
  332. ),
  333. ),
  334. ),
  335. add.si(1, 10),
  336. )
  337. res = c()
  338. assert res.get(timeout=TIMEOUT) == 11
  339. @flaky
  340. def test_single_task_header(self, manager):
  341. try:
  342. manager.app.backend.ensure_chords_allowed()
  343. except NotImplementedError as e:
  344. raise pytest.skip(e.args[0])
  345. c1 = chord([add.s(2, 5)], body=add_to_all.s(9))
  346. res1 = c1()
  347. assert res1.get(timeout=TIMEOUT) == [16]
  348. c2 = group([add.s(2, 5)]) | add_to_all.s(9)
  349. res2 = c2()
  350. assert res2.get(timeout=TIMEOUT) == [16]
  351. def test_empty_header_chord(self, manager):
  352. try:
  353. manager.app.backend.ensure_chords_allowed()
  354. except NotImplementedError as e:
  355. raise pytest.skip(e.args[0])
  356. c1 = chord([], body=add_to_all.s(9))
  357. res1 = c1()
  358. assert res1.get(timeout=TIMEOUT) == []
  359. c2 = group([]) | add_to_all.s(9)
  360. res2 = c2()
  361. assert res2.get(timeout=TIMEOUT) == []
  362. @flaky
  363. def test_nested_chord(self, manager):
  364. try:
  365. manager.app.backend.ensure_chords_allowed()
  366. except NotImplementedError as e:
  367. raise pytest.skip(e.args[0])
  368. c1 = chord([
  369. chord([add.s(1, 2), add.s(3, 4)], add.s([5])),
  370. chord([add.s(6, 7)], add.s([10]))
  371. ], add_to_all.s(['A']))
  372. res1 = c1()
  373. assert res1.get(timeout=TIMEOUT) == [[3, 7, 5, 'A'], [13, 10, 'A']]
  374. c2 = group([
  375. group([add.s(1, 2), add.s(3, 4)]) | add.s([5]),
  376. group([add.s(6, 7)]) | add.s([10]),
  377. ]) | add_to_all.s(['A'])
  378. res2 = c2()
  379. assert res2.get(timeout=TIMEOUT) == [[3, 7, 5, 'A'], [13, 10, 'A']]
  380. c = group([
  381. group([
  382. group([
  383. group([
  384. add.s(1, 2)
  385. ]) | add.s([3])
  386. ]) | add.s([4])
  387. ]) | add.s([5])
  388. ]) | add.s([6])
  389. res = c()
  390. assert [[[[3, 3], 4], 5], 6] == res.get(timeout=TIMEOUT)
  391. @flaky
  392. def test_parent_ids(self, manager):
  393. if not manager.app.conf.result_backend.startswith('redis'):
  394. raise pytest.skip('Requires redis result backend.')
  395. root = ids.si(i=1)
  396. expected_root_id = root.freeze().id
  397. g = chain(
  398. root, ids.si(i=2),
  399. chord(
  400. group(ids.si(i=i) for i in range(3, 50)),
  401. chain(collect_ids.s(i=50) | ids.si(i=51)),
  402. ),
  403. )
  404. self.assert_parentids_chord(g(), expected_root_id)
  405. @flaky
  406. def test_parent_ids__OR(self, manager):
  407. if not manager.app.conf.result_backend.startswith('redis'):
  408. raise pytest.skip('Requires redis result backend.')
  409. root = ids.si(i=1)
  410. expected_root_id = root.freeze().id
  411. g = (
  412. root |
  413. ids.si(i=2) |
  414. group(ids.si(i=i) for i in range(3, 50)) |
  415. collect_ids.s(i=50) |
  416. ids.si(i=51)
  417. )
  418. self.assert_parentids_chord(g(), expected_root_id)
  419. def assert_parentids_chord(self, res, expected_root_id):
  420. assert isinstance(res, AsyncResult)
  421. assert isinstance(res.parent, AsyncResult)
  422. assert isinstance(res.parent.parent, GroupResult)
  423. assert isinstance(res.parent.parent.parent, AsyncResult)
  424. assert isinstance(res.parent.parent.parent.parent, AsyncResult)
  425. # first we check the last task
  426. assert_ids(res, 51, expected_root_id, res.parent.id)
  427. # then the chord callback
  428. prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
  429. assert value == 50
  430. assert root_id == expected_root_id
  431. # started by one of the chord header tasks.
  432. assert parent_id in res.parent.parent.results
  433. # check what the chord callback recorded
  434. for i, p in enumerate(prev):
  435. root_id, parent_id, value = p
  436. assert root_id == expected_root_id
  437. assert parent_id == res.parent.parent.parent.id
  438. # ids(i=2)
  439. root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
  440. assert value == 2
  441. assert parent_id == res.parent.parent.parent.parent.id
  442. assert root_id == expected_root_id
  443. # ids(i=1)
  444. root_id, parent_id, value = res.parent.parent.parent.parent.get(
  445. timeout=30)
  446. assert value == 1
  447. assert root_id == expected_root_id
  448. assert parent_id is None
  449. def test_chord_on_error(self, manager):
  450. from celery import states
  451. from .tasks import ExpectedException
  452. import time
  453. if not manager.app.conf.result_backend.startswith('redis'):
  454. raise pytest.skip('Requires redis result backend.')
  455. # Run the chord and wait for the error callback to finish.
  456. c1 = chord(
  457. header=[add.s(1, 2), add.s(3, 4), fail.s()],
  458. body=print_unicode.s('This should not be called').on_error(
  459. chord_error.s()),
  460. )
  461. res = c1()
  462. try:
  463. res.wait(propagate=False)
  464. except ExpectedException:
  465. pass
  466. # Got to wait for children to populate.
  467. while not res.children:
  468. time.sleep(0.1)
  469. try:
  470. res.children[0].children[0].wait(propagate=False)
  471. except ExpectedException:
  472. pass
  473. # Extract the results of the successful tasks from the chord.
  474. #
  475. # We could do this inside the error handler, and probably would in a
  476. # real system, but for the purposes of the test it's obnoxious to get
  477. # data out of the error handler.
  478. #
  479. # So for clarity of our test, we instead do it here.
  480. # Use the error callback's result to find the failed task.
  481. error_callback_result = AsyncResult(
  482. res.children[0].children[0].result[0])
  483. failed_task_id = error_callback_result.result.args[0].split()[3]
  484. # Use new group_id result metadata to get group ID.
  485. failed_task_result = AsyncResult(failed_task_id)
  486. original_group_id = failed_task_result._get_task_meta()['group_id']
  487. # Use group ID to get preserved group result.
  488. backend = fail.app.backend
  489. j_key = backend.get_key_for_group(original_group_id, '.j')
  490. redis_connection = get_redis_connection()
  491. chord_results = [backend.decode(t) for t in
  492. redis_connection.lrange(j_key, 0, 3)]
  493. # Validate group result
  494. assert [cr[3] for cr in chord_results if cr[2] == states.SUCCESS] == \
  495. [3, 7]
  496. assert len([cr for cr in chord_results if cr[2] != states.SUCCESS]
  497. ) == 1
  498. def test_parallel_chords(self, manager):
  499. try:
  500. manager.app.backend.ensure_chords_allowed()
  501. except NotImplementedError as e:
  502. raise pytest.skip(e.args[0])
  503. c1 = chord(group(add.s(1, 2), add.s(3, 4)), tsum.s())
  504. c2 = chord(group(add.s(1, 2), add.s(3, 4)), tsum.s())
  505. g = group(c1, c2)
  506. r = g.delay()
  507. assert r.get(timeout=TIMEOUT) == [10, 10]