test_canvas.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. from __future__ import absolute_import, unicode_literals
  2. from datetime import datetime, timedelta
  3. import pytest
  4. from celery import chain, chord, group
  5. from celery.exceptions import TimeoutError
  6. from celery.result import AsyncResult, GroupResult, ResultSet
  7. from .conftest import flaky, get_active_redis_channels, get_redis_connection
  8. from .tasks import (add, add_chord_to_chord, add_replaced, add_to_all,
  9. add_to_all_to_chord, build_chain_inside_task, collect_ids,
  10. delayed_sum, delayed_sum_with_soft_guard, identity, ids,
  11. print_unicode, redis_echo, second_order_replace1, tsum)
  12. TIMEOUT = 120
  13. class test_chain:
  14. @flaky
  15. def test_simple_chain(self, manager):
  16. c = add.s(4, 4) | add.s(8) | add.s(16)
  17. assert c().get(timeout=TIMEOUT) == 32
  18. @flaky
  19. def test_single_chain(self, manager):
  20. c = chain(add.s(3, 4))()
  21. assert c.get(timeout=TIMEOUT) == 7
  22. @flaky
  23. def test_complex_chain(self, manager):
  24. c = (
  25. add.s(2, 2) | (
  26. add.s(4) | add_replaced.s(8) | add.s(16) | add.s(32)
  27. ) |
  28. group(add.s(i) for i in range(4))
  29. )
  30. res = c()
  31. assert res.get(timeout=TIMEOUT) == [64, 65, 66, 67]
  32. @flaky
  33. def test_group_results_in_chain(self, manager):
  34. # This adds in an explicit test for the special case added in commit
  35. # 1e3fcaa969de6ad32b52a3ed8e74281e5e5360e6
  36. c = (
  37. group(
  38. add.s(1, 2) | group(
  39. add.s(1), add.s(2)
  40. )
  41. )
  42. )
  43. res = c()
  44. assert res.get(timeout=TIMEOUT) == [4, 5]
  45. @flaky
  46. def test_chain_inside_group_receives_arguments(self, manager):
  47. c = (
  48. add.s(5, 6) |
  49. group((add.s(1) | add.s(2), add.s(3)))
  50. )
  51. res = c()
  52. assert res.get(timeout=TIMEOUT) == [14, 14]
  53. @flaky
  54. def test_eager_chain_inside_task(self, manager):
  55. from .tasks import chain_add
  56. prev = chain_add.app.conf.task_always_eager
  57. chain_add.app.conf.task_always_eager = True
  58. chain_add.apply_async(args=(4, 8), throw=True).get()
  59. chain_add.app.conf.task_always_eager = prev
  60. @flaky
  61. def test_group_chord_group_chain(self, manager):
  62. from celery.five import bytes_if_py2
  63. if not manager.app.conf.result_backend.startswith('redis'):
  64. raise pytest.skip('Requires redis result backend.')
  65. redis_connection = get_redis_connection()
  66. redis_connection.delete('redis-echo')
  67. before = group(redis_echo.si('before {}'.format(i)) for i in range(3))
  68. connect = redis_echo.si('connect')
  69. after = group(redis_echo.si('after {}'.format(i)) for i in range(2))
  70. result = (before | connect | after).delay()
  71. result.get(timeout=TIMEOUT)
  72. redis_messages = list(map(
  73. bytes_if_py2,
  74. redis_connection.lrange('redis-echo', 0, -1)
  75. ))
  76. before_items = \
  77. set(map(bytes_if_py2, (b'before 0', b'before 1', b'before 2')))
  78. after_items = set(map(bytes_if_py2, (b'after 0', b'after 1')))
  79. assert set(redis_messages[:3]) == before_items
  80. assert redis_messages[3] == b'connect'
  81. assert set(redis_messages[4:]) == after_items
  82. redis_connection.delete('redis-echo')
  83. @flaky
  84. def test_second_order_replace(self, manager):
  85. from celery.five import bytes_if_py2
  86. if not manager.app.conf.result_backend.startswith('redis'):
  87. raise pytest.skip('Requires redis result backend.')
  88. redis_connection = get_redis_connection()
  89. redis_connection.delete('redis-echo')
  90. result = second_order_replace1.delay()
  91. result.get(timeout=TIMEOUT)
  92. redis_messages = list(map(
  93. bytes_if_py2,
  94. redis_connection.lrange('redis-echo', 0, -1)
  95. ))
  96. expected_messages = [b'In A', b'In B', b'In/Out C', b'Out B', b'Out A']
  97. assert redis_messages == expected_messages
  98. @flaky
  99. def test_parent_ids(self, manager, num=10):
  100. assert manager.inspect().ping()
  101. c = chain(ids.si(i=i) for i in range(num))
  102. c.freeze()
  103. res = c()
  104. try:
  105. res.get(timeout=TIMEOUT)
  106. except TimeoutError:
  107. print(manager.inspect.active())
  108. print(manager.inspect.reserved())
  109. print(manager.inspect.stats())
  110. raise
  111. self.assert_ids(res, num - 1)
  112. def assert_ids(self, res, size):
  113. i, root = size, res
  114. while root.parent:
  115. root = root.parent
  116. node = res
  117. while node:
  118. root_id, parent_id, value = node.get(timeout=30)
  119. assert value == i
  120. if node.parent:
  121. assert parent_id == node.parent.id
  122. assert root_id == root.id
  123. node = node.parent
  124. i -= 1
  125. def test_chord_soft_timeout_recuperation(self, manager):
  126. """Test that if soft timeout happens in task but is managed by task,
  127. chord still get results normally
  128. """
  129. if not manager.app.conf.result_backend.startswith('redis'):
  130. raise pytest.skip('Requires redis result backend.')
  131. c = chord([
  132. # return 3
  133. add.s(1, 2),
  134. # return 0 after managing soft timeout
  135. delayed_sum_with_soft_guard.s(
  136. [100], pause_time=2
  137. ).set(
  138. soft_time_limit=1
  139. ),
  140. ])
  141. result = c(delayed_sum.s(pause_time=0)).get()
  142. assert result == 3
  143. def test_chain_error_handler_with_eta(self, manager):
  144. try:
  145. manager.app.backend.ensure_chords_allowed()
  146. except NotImplementedError as e:
  147. raise pytest.skip(e.args[0])
  148. eta = datetime.utcnow() + timedelta(seconds=10)
  149. c = chain(
  150. group(
  151. add.s(1, 2),
  152. add.s(3, 4),
  153. ),
  154. tsum.s()
  155. ).on_error(print_unicode.s()).apply_async(eta=eta)
  156. result = c.get()
  157. assert result == 10
  158. @flaky
  159. def test_groupresult_serialization(self, manager):
  160. """Test GroupResult is correctly serialized
  161. to save in the result backend"""
  162. try:
  163. manager.app.backend.ensure_chords_allowed()
  164. except NotImplementedError as e:
  165. raise pytest.skip(e.args[0])
  166. async_result = build_chain_inside_task.delay()
  167. result = async_result.get()
  168. assert len(result) == 2
  169. assert isinstance(result[0][1], list)
  170. class test_result_set:
  171. @flaky
  172. def test_result_set(self, manager):
  173. assert manager.inspect().ping()
  174. rs = ResultSet([add.delay(1, 1), add.delay(2, 2)])
  175. assert rs.get(timeout=TIMEOUT) == [2, 4]
  176. class test_group:
  177. @flaky
  178. def test_empty_group_result(self, manager):
  179. if not manager.app.conf.result_backend.startswith('redis'):
  180. raise pytest.skip('Requires redis result backend.')
  181. task = group([])
  182. result = task.apply_async()
  183. GroupResult.save(result)
  184. task = GroupResult.restore(result.id)
  185. assert task.results == []
  186. @flaky
  187. def test_parent_ids(self, manager):
  188. assert manager.inspect().ping()
  189. g = (
  190. ids.si(i=1) |
  191. ids.si(i=2) |
  192. group(ids.si(i=i) for i in range(2, 50))
  193. )
  194. res = g()
  195. expected_root_id = res.parent.parent.id
  196. expected_parent_id = res.parent.id
  197. values = res.get(timeout=TIMEOUT)
  198. for i, r in enumerate(values):
  199. root_id, parent_id, value = r
  200. assert root_id == expected_root_id
  201. assert parent_id == expected_parent_id
  202. assert value == i + 2
  203. @flaky
  204. def test_nested_group(self, manager):
  205. assert manager.inspect().ping()
  206. c = group(
  207. add.si(1, 10),
  208. group(
  209. add.si(1, 100),
  210. group(
  211. add.si(1, 1000),
  212. add.si(1, 2000),
  213. ),
  214. ),
  215. )
  216. res = c()
  217. assert res.get(timeout=TIMEOUT) == [11, 101, 1001, 2001]
  218. def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
  219. root_id, parent_id, value = r.get(timeout=TIMEOUT)
  220. assert expected_value == value
  221. assert root_id == expected_root_id
  222. assert parent_id == expected_parent_id
  223. class test_chord:
  224. @flaky
  225. def test_redis_subscribed_channels_leak(self, manager):
  226. if not manager.app.conf.result_backend.startswith('redis'):
  227. raise pytest.skip('Requires redis result backend.')
  228. manager.app.backend.result_consumer.on_after_fork()
  229. initial_channels = get_active_redis_channels()
  230. initial_channels_count = len(initial_channels)
  231. total_chords = 10
  232. async_results = [
  233. chord([add.s(5, 6), add.s(6, 7)])(delayed_sum.s())
  234. for _ in range(total_chords)
  235. ]
  236. manager.assert_result_tasks_in_progress_or_completed(async_results)
  237. channels_before = get_active_redis_channels()
  238. channels_before_count = len(channels_before)
  239. assert set(channels_before) != set(initial_channels)
  240. assert channels_before_count > initial_channels_count
  241. # The total number of active Redis channels at this point
  242. # is the number of chord header tasks multiplied by the
  243. # total chord tasks, plus the initial channels
  244. # (existing from previous tests).
  245. chord_header_task_count = 2
  246. assert channels_before_count <= \
  247. chord_header_task_count * total_chords + initial_channels_count
  248. result_values = [
  249. result.get(timeout=TIMEOUT)
  250. for result in async_results
  251. ]
  252. assert result_values == [24] * total_chords
  253. channels_after = get_active_redis_channels()
  254. channels_after_count = len(channels_after)
  255. assert channels_after_count == initial_channels_count
  256. assert set(channels_after) == set(initial_channels)
  257. @flaky
  258. def test_replaced_nested_chord(self, manager):
  259. try:
  260. manager.app.backend.ensure_chords_allowed()
  261. except NotImplementedError as e:
  262. raise pytest.skip(e.args[0])
  263. c1 = chord([
  264. chord(
  265. [add.s(1, 2), add_replaced.s(3, 4)],
  266. add_to_all.s(5),
  267. ) | tsum.s(),
  268. chord(
  269. [add_replaced.s(6, 7), add.s(0, 0)],
  270. add_to_all.s(8),
  271. ) | tsum.s(),
  272. ], add_to_all.s(9))
  273. res1 = c1()
  274. assert res1.get(timeout=TIMEOUT) == [29, 38]
  275. @flaky
  276. def test_add_to_chord(self, manager):
  277. if not manager.app.conf.result_backend.startswith('redis'):
  278. raise pytest.skip('Requires redis result backend.')
  279. c = group([add_to_all_to_chord.s([1, 2, 3], 4)]) | identity.s()
  280. res = c()
  281. assert res.get() == [0, 5, 6, 7]
  282. @flaky
  283. def test_add_chord_to_chord(self, manager):
  284. if not manager.app.conf.result_backend.startswith('redis'):
  285. raise pytest.skip('Requires redis result backend.')
  286. c = group([add_chord_to_chord.s([1, 2, 3], 4)]) | identity.s()
  287. res = c()
  288. assert res.get() == [0, 5 + 6 + 7]
  289. @flaky
  290. def test_group_chain(self, manager):
  291. if not manager.app.conf.result_backend.startswith('redis'):
  292. raise pytest.skip('Requires redis result backend.')
  293. c = (
  294. add.s(2, 2) |
  295. group(add.s(i) for i in range(4)) |
  296. add_to_all.s(8)
  297. )
  298. res = c()
  299. assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15]
  300. @flaky
  301. def test_nested_group_chain(self, manager):
  302. try:
  303. manager.app.backend.ensure_chords_allowed()
  304. except NotImplementedError as e:
  305. raise pytest.skip(e.args[0])
  306. if not manager.app.backend.supports_native_join:
  307. raise pytest.skip('Requires native join support.')
  308. c = chain(
  309. add.si(1, 0),
  310. group(
  311. add.si(1, 100),
  312. chain(
  313. add.si(1, 200),
  314. group(
  315. add.si(1, 1000),
  316. add.si(1, 2000),
  317. ),
  318. ),
  319. ),
  320. add.si(1, 10),
  321. )
  322. res = c()
  323. assert res.get(timeout=TIMEOUT) == 11
  324. @flaky
  325. def test_single_task_header(self, manager):
  326. try:
  327. manager.app.backend.ensure_chords_allowed()
  328. except NotImplementedError as e:
  329. raise pytest.skip(e.args[0])
  330. c1 = chord([add.s(2, 5)], body=add_to_all.s(9))
  331. res1 = c1()
  332. assert res1.get(timeout=TIMEOUT) == [16]
  333. c2 = group([add.s(2, 5)]) | add_to_all.s(9)
  334. res2 = c2()
  335. assert res2.get(timeout=TIMEOUT) == [16]
  336. def test_empty_header_chord(self, manager):
  337. try:
  338. manager.app.backend.ensure_chords_allowed()
  339. except NotImplementedError as e:
  340. raise pytest.skip(e.args[0])
  341. c1 = chord([], body=add_to_all.s(9))
  342. res1 = c1()
  343. assert res1.get(timeout=TIMEOUT) == []
  344. c2 = group([]) | add_to_all.s(9)
  345. res2 = c2()
  346. assert res2.get(timeout=TIMEOUT) == []
  347. @flaky
  348. def test_nested_chord(self, manager):
  349. try:
  350. manager.app.backend.ensure_chords_allowed()
  351. except NotImplementedError as e:
  352. raise pytest.skip(e.args[0])
  353. c1 = chord([
  354. chord([add.s(1, 2), add.s(3, 4)], add.s([5])),
  355. chord([add.s(6, 7)], add.s([10]))
  356. ], add_to_all.s(['A']))
  357. res1 = c1()
  358. assert res1.get(timeout=TIMEOUT) == [[3, 7, 5, 'A'], [13, 10, 'A']]
  359. c2 = group([
  360. group([add.s(1, 2), add.s(3, 4)]) | add.s([5]),
  361. group([add.s(6, 7)]) | add.s([10]),
  362. ]) | add_to_all.s(['A'])
  363. res2 = c2()
  364. assert res2.get(timeout=TIMEOUT) == [[3, 7, 5, 'A'], [13, 10, 'A']]
  365. c = group([
  366. group([
  367. group([
  368. group([
  369. add.s(1, 2)
  370. ]) | add.s([3])
  371. ]) | add.s([4])
  372. ]) | add.s([5])
  373. ]) | add.s([6])
  374. res = c()
  375. assert [[[[3, 3], 4], 5], 6] == res.get(timeout=TIMEOUT)
  376. @flaky
  377. def test_parent_ids(self, manager):
  378. if not manager.app.conf.result_backend.startswith('redis'):
  379. raise pytest.skip('Requires redis result backend.')
  380. root = ids.si(i=1)
  381. expected_root_id = root.freeze().id
  382. g = chain(
  383. root, ids.si(i=2),
  384. chord(
  385. group(ids.si(i=i) for i in range(3, 50)),
  386. chain(collect_ids.s(i=50) | ids.si(i=51)),
  387. ),
  388. )
  389. self.assert_parentids_chord(g(), expected_root_id)
  390. @flaky
  391. def test_parent_ids__OR(self, manager):
  392. if not manager.app.conf.result_backend.startswith('redis'):
  393. raise pytest.skip('Requires redis result backend.')
  394. root = ids.si(i=1)
  395. expected_root_id = root.freeze().id
  396. g = (
  397. root |
  398. ids.si(i=2) |
  399. group(ids.si(i=i) for i in range(3, 50)) |
  400. collect_ids.s(i=50) |
  401. ids.si(i=51)
  402. )
  403. self.assert_parentids_chord(g(), expected_root_id)
  404. def assert_parentids_chord(self, res, expected_root_id):
  405. assert isinstance(res, AsyncResult)
  406. assert isinstance(res.parent, AsyncResult)
  407. assert isinstance(res.parent.parent, GroupResult)
  408. assert isinstance(res.parent.parent.parent, AsyncResult)
  409. assert isinstance(res.parent.parent.parent.parent, AsyncResult)
  410. # first we check the last task
  411. assert_ids(res, 51, expected_root_id, res.parent.id)
  412. # then the chord callback
  413. prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
  414. assert value == 50
  415. assert root_id == expected_root_id
  416. # started by one of the chord header tasks.
  417. assert parent_id in res.parent.parent.results
  418. # check what the chord callback recorded
  419. for i, p in enumerate(prev):
  420. root_id, parent_id, value = p
  421. assert root_id == expected_root_id
  422. assert parent_id == res.parent.parent.parent.id
  423. # ids(i=2)
  424. root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
  425. assert value == 2
  426. assert parent_id == res.parent.parent.parent.parent.id
  427. assert root_id == expected_root_id
  428. # ids(i=1)
  429. root_id, parent_id, value = res.parent.parent.parent.parent.get(
  430. timeout=30)
  431. assert value == 1
  432. assert root_id == expected_root_id
  433. assert parent_id is None