test_chord.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from contextlib import contextmanager
  4. from case import Mock
  5. from celery import group, uuid
  6. from celery import canvas
  7. from celery import result
  8. from celery.exceptions import ChordError, Retry
  9. from celery.five import range
  10. from celery.result import AsyncResult, GroupResult, EagerResult
  11. def passthru(x):
  12. return x
  13. class ChordCase:
  14. def setup(self):
  15. @self.app.task(shared=False)
  16. def add(x, y):
  17. return x + y
  18. self.add = add
  19. class TSR(GroupResult):
  20. is_ready = True
  21. value = None
  22. def ready(self):
  23. return self.is_ready
  24. def join(self, propagate=True, **kwargs):
  25. if propagate:
  26. for value in self.value:
  27. if isinstance(value, Exception):
  28. raise value
  29. return self.value
  30. join_native = join
  31. def _failed_join_report(self):
  32. for value in self.value:
  33. if isinstance(value, Exception):
  34. yield EagerResult('some_id', value, 'FAILURE')
  35. class TSRNoReport(TSR):
  36. def _failed_join_report(self):
  37. return iter([])
  38. @contextmanager
  39. def patch_unlock_retry(app):
  40. unlock = app.tasks['celery.chord_unlock']
  41. retry = Mock()
  42. retry.return_value = Retry()
  43. prev, unlock.retry = unlock.retry, retry
  44. try:
  45. yield unlock, retry
  46. finally:
  47. unlock.retry = prev
  48. class test_unlock_chord_task(ChordCase):
  49. def test_unlock_ready(self):
  50. class AlwaysReady(TSR):
  51. is_ready = True
  52. value = [2, 4, 8, 6]
  53. with self._chord_context(AlwaysReady) as (cb, retry, _):
  54. cb.type.apply_async.assert_called_with(
  55. ([2, 4, 8, 6],), {}, task_id=cb.id,
  56. )
  57. # didn't retry
  58. assert not retry.call_count
  59. def test_deps_ready_fails(self):
  60. GroupResult = Mock(name='GroupResult')
  61. GroupResult.return_value.ready.side_effect = KeyError('foo')
  62. unlock_chord = self.app.tasks['celery.chord_unlock']
  63. with pytest.raises(KeyError):
  64. unlock_chord('groupid', Mock(), result=[Mock()],
  65. GroupResult=GroupResult, result_from_tuple=Mock())
  66. def test_callback_fails(self):
  67. class AlwaysReady(TSR):
  68. is_ready = True
  69. value = [2, 4, 8, 6]
  70. def setup(callback):
  71. callback.apply_async.side_effect = IOError()
  72. with self._chord_context(AlwaysReady, setup) as (cb, retry, fail):
  73. fail.assert_called()
  74. assert fail.call_args[0][0] == cb.id
  75. assert isinstance(fail.call_args[1]['exc'], ChordError)
  76. def test_unlock_ready_failed(self):
  77. class Failed(TSR):
  78. is_ready = True
  79. value = [2, KeyError('foo'), 8, 6]
  80. with self._chord_context(Failed) as (cb, retry, fail_current):
  81. cb.type.apply_async.assert_not_called()
  82. # didn't retry
  83. assert not retry.call_count
  84. fail_current.assert_called()
  85. assert fail_current.call_args[0][0] == cb.id
  86. assert isinstance(fail_current.call_args[1]['exc'], ChordError)
  87. assert 'some_id' in str(fail_current.call_args[1]['exc'])
  88. def test_unlock_ready_failed_no_culprit(self):
  89. class Failed(TSRNoReport):
  90. is_ready = True
  91. value = [2, KeyError('foo'), 8, 6]
  92. with self._chord_context(Failed) as (cb, retry, fail_current):
  93. fail_current.assert_called()
  94. assert fail_current.call_args[0][0] == cb.id
  95. assert isinstance(fail_current.call_args[1]['exc'], ChordError)
  96. @contextmanager
  97. def _chord_context(self, ResultCls, setup=None, **kwargs):
  98. @self.app.task(shared=False)
  99. def callback(*args, **kwargs):
  100. pass
  101. self.app.finalize()
  102. pts, result.GroupResult = result.GroupResult, ResultCls
  103. callback.apply_async = Mock()
  104. callback_s = callback.s()
  105. callback_s.id = 'callback_id'
  106. fail_current = self.app.backend.fail_from_current_stack = Mock()
  107. try:
  108. with patch_unlock_retry(self.app) as (unlock, retry):
  109. signature, canvas.maybe_signature = (
  110. canvas.maybe_signature, passthru,
  111. )
  112. if setup:
  113. setup(callback)
  114. try:
  115. assert self.app.tasks['celery.chord_unlock'] is unlock
  116. try:
  117. unlock(
  118. 'group_id', callback_s,
  119. result=[
  120. self.app.AsyncResult(r) for r in ['1', 2, 3]
  121. ],
  122. GroupResult=ResultCls, **kwargs
  123. )
  124. except Retry:
  125. pass
  126. finally:
  127. canvas.maybe_signature = signature
  128. yield callback_s, retry, fail_current
  129. finally:
  130. result.GroupResult = pts
  131. def test_when_not_ready(self):
  132. class NeverReady(TSR):
  133. is_ready = False
  134. with self._chord_context(NeverReady, interval=10, max_retries=30) \
  135. as (cb, retry, _):
  136. cb.type.apply_async.assert_not_called()
  137. # did retry
  138. retry.assert_called_with(countdown=10, max_retries=30)
  139. def test_is_in_registry(self):
  140. assert 'celery.chord_unlock' in self.app.tasks
  141. class test_chord(ChordCase):
  142. def test_eager(self):
  143. from celery import chord
  144. @self.app.task(shared=False)
  145. def addX(x, y):
  146. return x + y
  147. @self.app.task(shared=False)
  148. def sumX(n):
  149. return sum(n)
  150. self.app.conf.task_always_eager = True
  151. x = chord(addX.s(i, i) for i in range(10))
  152. body = sumX.s()
  153. result = x(body)
  154. assert result.get() == sum(i + i for i in range(10))
  155. def test_apply(self):
  156. self.app.conf.task_always_eager = False
  157. from celery import chord
  158. m = Mock()
  159. m.app.conf.task_always_eager = False
  160. m.AsyncResult = AsyncResult
  161. prev, chord.run = chord.run, m
  162. try:
  163. x = chord(self.add.s(i, i) for i in range(10))
  164. body = self.add.s(2)
  165. result = x(body)
  166. assert result.id
  167. # does not modify original signature
  168. with pytest.raises(KeyError):
  169. body.options['task_id']
  170. chord.run.assert_called()
  171. finally:
  172. chord.run = prev
  173. class test_add_to_chord:
  174. def setup(self):
  175. @self.app.task(shared=False)
  176. def add(x, y):
  177. return x + y
  178. self.add = add
  179. @self.app.task(shared=False, bind=True)
  180. def adds(self, sig, lazy=False):
  181. return self.add_to_chord(sig, lazy)
  182. self.adds = adds
  183. def test_add_to_chord(self):
  184. self.app.backend = Mock(name='backend')
  185. sig = self.add.s(2, 2)
  186. sig.delay = Mock(name='sig.delay')
  187. self.adds.request.group = uuid()
  188. self.adds.request.id = uuid()
  189. with pytest.raises(ValueError):
  190. # task not part of chord
  191. self.adds.run(sig)
  192. self.adds.request.chord = self.add.s()
  193. res1 = self.adds.run(sig, True)
  194. assert res1 == sig
  195. assert sig.options['task_id']
  196. assert sig.options['group_id'] == self.adds.request.group
  197. assert sig.options['chord'] == self.adds.request.chord
  198. sig.delay.assert_not_called()
  199. self.app.backend.add_to_chord.assert_called_with(
  200. self.adds.request.group, sig.freeze(),
  201. )
  202. self.app.backend.reset_mock()
  203. sig2 = self.add.s(4, 4)
  204. sig2.delay = Mock(name='sig2.delay')
  205. res2 = self.adds.run(sig2)
  206. assert res2 == sig2.delay.return_value
  207. assert sig2.options['task_id']
  208. assert sig2.options['group_id'] == self.adds.request.group
  209. assert sig2.options['chord'] == self.adds.request.chord
  210. sig2.delay.assert_called_with()
  211. self.app.backend.add_to_chord.assert_called_with(
  212. self.adds.request.group, sig2.freeze(),
  213. )
  214. class test_Chord_task(ChordCase):
  215. def test_run(self):
  216. self.app.backend = Mock()
  217. self.app.backend.cleanup = Mock()
  218. self.app.backend.cleanup.__name__ = 'cleanup'
  219. Chord = self.app.tasks['celery.chord']
  220. body = self.add.signature()
  221. Chord(group(self.add.signature((i, i)) for i in range(5)), body)
  222. Chord([self.add.signature((j, j)) for j in range(5)], body)
  223. assert self.app.backend.apply_chord.call_count == 2