test_chord.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import pytest
  2. from contextlib import contextmanager
  3. from case import Mock
  4. from celery import group, uuid
  5. from celery import canvas
  6. from celery import result
  7. from celery.exceptions import ChordError, Retry
  8. from celery.result import AsyncResult, GroupResult, EagerResult
  9. def passthru(x):
  10. return x
  11. class ChordCase:
  12. def setup(self):
  13. @self.app.task(shared=False)
  14. def add(x, y):
  15. return x + y
  16. self.add = add
  17. class TSR(GroupResult):
  18. is_ready = True
  19. value = None
  20. def ready(self):
  21. return self.is_ready
  22. def join(self, propagate=True, **kwargs):
  23. if propagate:
  24. for value in self.value:
  25. if isinstance(value, Exception):
  26. raise value
  27. return self.value
  28. join_native = join
  29. def _failed_join_report(self):
  30. for value in self.value:
  31. if isinstance(value, Exception):
  32. yield EagerResult('some_id', value, 'FAILURE')
  33. class TSRNoReport(TSR):
  34. def _failed_join_report(self):
  35. return iter([])
  36. @contextmanager
  37. def patch_unlock_retry(app):
  38. unlock = app.tasks['celery.chord_unlock']
  39. retry = Mock()
  40. retry.return_value = Retry()
  41. prev, unlock.retry = unlock.retry, retry
  42. try:
  43. yield unlock, retry
  44. finally:
  45. unlock.retry = prev
  46. class test_unlock_chord_task(ChordCase):
  47. def test_unlock_ready(self):
  48. class AlwaysReady(TSR):
  49. is_ready = True
  50. value = [2, 4, 8, 6]
  51. with self._chord_context(AlwaysReady) as (cb, retry, _):
  52. cb.type.apply_async.assert_called_with(
  53. ([2, 4, 8, 6],), {}, task_id=cb.id,
  54. )
  55. # didn't retry
  56. assert not retry.call_count
  57. def test_deps_ready_fails(self):
  58. GroupResult = Mock(name='GroupResult')
  59. GroupResult.return_value.ready.side_effect = KeyError('foo')
  60. unlock_chord = self.app.tasks['celery.chord_unlock']
  61. with pytest.raises(KeyError):
  62. unlock_chord('groupid', Mock(), result=[Mock()],
  63. GroupResult=GroupResult, result_from_tuple=Mock())
  64. def test_callback_fails(self):
  65. class AlwaysReady(TSR):
  66. is_ready = True
  67. value = [2, 4, 8, 6]
  68. def setup(callback):
  69. callback.apply_async.side_effect = IOError()
  70. with self._chord_context(AlwaysReady, setup) as (cb, retry, fail):
  71. fail.assert_called()
  72. assert fail.call_args[0][0] == cb.id
  73. assert isinstance(fail.call_args[1]['exc'], ChordError)
  74. def test_unlock_ready_failed(self):
  75. class Failed(TSR):
  76. is_ready = True
  77. value = [2, KeyError('foo'), 8, 6]
  78. with self._chord_context(Failed) as (cb, retry, fail_current):
  79. cb.type.apply_async.assert_not_called()
  80. # didn't retry
  81. assert not retry.call_count
  82. fail_current.assert_called()
  83. assert fail_current.call_args[0][0] == cb.id
  84. assert isinstance(fail_current.call_args[1]['exc'], ChordError)
  85. assert 'some_id' in str(fail_current.call_args[1]['exc'])
  86. def test_unlock_ready_failed_no_culprit(self):
  87. class Failed(TSRNoReport):
  88. is_ready = True
  89. value = [2, KeyError('foo'), 8, 6]
  90. with self._chord_context(Failed) as (cb, retry, fail_current):
  91. fail_current.assert_called()
  92. assert fail_current.call_args[0][0] == cb.id
  93. assert isinstance(fail_current.call_args[1]['exc'], ChordError)
  94. @contextmanager
  95. def _chord_context(self, ResultCls, setup=None, **kwargs):
  96. @self.app.task(shared=False)
  97. def callback(*args, **kwargs):
  98. pass
  99. self.app.finalize()
  100. pts, result.GroupResult = result.GroupResult, ResultCls
  101. callback.apply_async = Mock()
  102. callback_s = callback.s()
  103. callback_s.id = 'callback_id'
  104. fail_current = self.app.backend.fail_from_current_stack = Mock()
  105. try:
  106. with patch_unlock_retry(self.app) as (unlock, retry):
  107. signature, canvas.maybe_signature = (
  108. canvas.maybe_signature, passthru,
  109. )
  110. if setup:
  111. setup(callback)
  112. try:
  113. assert self.app.tasks['celery.chord_unlock'] is unlock
  114. try:
  115. unlock(
  116. 'group_id', callback_s,
  117. result=[
  118. self.app.AsyncResult(r) for r in ['1', 2, 3]
  119. ],
  120. GroupResult=ResultCls, **kwargs
  121. )
  122. except Retry:
  123. pass
  124. finally:
  125. canvas.maybe_signature = signature
  126. yield callback_s, retry, fail_current
  127. finally:
  128. result.GroupResult = pts
  129. def test_when_not_ready(self):
  130. class NeverReady(TSR):
  131. is_ready = False
  132. with self._chord_context(NeverReady, interval=10, max_retries=30) \
  133. as (cb, retry, _):
  134. cb.type.apply_async.assert_not_called()
  135. # did retry
  136. retry.assert_called_with(countdown=10, max_retries=30)
  137. def test_is_in_registry(self):
  138. assert 'celery.chord_unlock' in self.app.tasks
  139. class test_chord(ChordCase):
  140. def test_eager(self):
  141. from celery import chord
  142. @self.app.task(shared=False)
  143. def addX(x, y):
  144. return x + y
  145. @self.app.task(shared=False)
  146. def sumX(n):
  147. return sum(n)
  148. self.app.conf.task_always_eager = True
  149. x = chord(addX.s(i, i) for i in range(10))
  150. body = sumX.s()
  151. result = x(body)
  152. assert result.get() == sum(i + i for i in range(10))
  153. def test_apply(self):
  154. self.app.conf.task_always_eager = False
  155. from celery import chord
  156. m = Mock()
  157. m.app.conf.task_always_eager = False
  158. m.AsyncResult = AsyncResult
  159. prev, chord.run = chord.run, m
  160. try:
  161. x = chord(self.add.s(i, i) for i in range(10))
  162. body = self.add.s(2)
  163. result = x(body)
  164. assert result.id
  165. # does not modify original signature
  166. with pytest.raises(KeyError):
  167. body.options['task_id']
  168. chord.run.assert_called()
  169. finally:
  170. chord.run = prev
  171. class test_add_to_chord:
  172. def setup(self):
  173. @self.app.task(shared=False)
  174. def add(x, y):
  175. return x + y
  176. self.add = add
  177. @self.app.task(shared=False, bind=True)
  178. def adds(self, sig, lazy=False):
  179. return self.add_to_chord(sig, lazy)
  180. self.adds = adds
  181. def test_add_to_chord(self):
  182. self.app.backend = Mock(name='backend')
  183. sig = self.add.s(2, 2)
  184. sig.delay = Mock(name='sig.delay')
  185. self.adds.request.group = uuid()
  186. self.adds.request.id = uuid()
  187. with pytest.raises(ValueError):
  188. # task not part of chord
  189. self.adds.run(sig)
  190. self.adds.request.chord = self.add.s()
  191. res1 = self.adds.run(sig, True)
  192. assert res1 == sig
  193. assert sig.options['task_id']
  194. assert sig.options['group_id'] == self.adds.request.group
  195. assert sig.options['chord'] == self.adds.request.chord
  196. sig.delay.assert_not_called()
  197. self.app.backend.add_to_chord.assert_called_with(
  198. self.adds.request.group, sig.freeze(),
  199. )
  200. self.app.backend.reset_mock()
  201. sig2 = self.add.s(4, 4)
  202. sig2.delay = Mock(name='sig2.delay')
  203. res2 = self.adds.run(sig2)
  204. assert res2 == sig2.delay.return_value
  205. assert sig2.options['task_id']
  206. assert sig2.options['group_id'] == self.adds.request.group
  207. assert sig2.options['chord'] == self.adds.request.chord
  208. sig2.delay.assert_called_with()
  209. self.app.backend.add_to_chord.assert_called_with(
  210. self.adds.request.group, sig2.freeze(),
  211. )
  212. class test_Chord_task(ChordCase):
  213. def test_run(self):
  214. self.app.backend = Mock()
  215. self.app.backend.cleanup = Mock()
  216. self.app.backend.cleanup.__name__ = 'cleanup'
  217. Chord = self.app.tasks['celery.chord']
  218. body = self.add.signature()
  219. Chord(group(self.add.signature((i, i)) for i in range(5)), body)
  220. Chord([self.add.signature((j, j)) for j in range(5)], body)
  221. assert self.app.backend.apply_chord.call_count == 2