test_chord.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. from __future__ import absolute_import
  2. from contextlib import contextmanager
  3. from celery import group
  4. from celery import canvas
  5. from celery import result
  6. from celery.exceptions import ChordError, Retry
  7. from celery.five import range
  8. from celery.result import AsyncResult, GroupResult, EagerResult
  9. from celery.tests.case import AppCase, Mock
  10. passthru = lambda x: x
  11. class ChordCase(AppCase):
  12. def setup(self):
  13. @self.app.task(shared=False)
  14. def add(x, y):
  15. return x + y
  16. self.add = add
  17. class TSR(GroupResult):
  18. is_ready = True
  19. value = None
  20. def ready(self):
  21. return self.is_ready
  22. def join(self, propagate=True, **kwargs):
  23. if propagate:
  24. for value in self.value:
  25. if isinstance(value, Exception):
  26. raise value
  27. return self.value
  28. join_native = join
  29. def _failed_join_report(self):
  30. for value in self.value:
  31. if isinstance(value, Exception):
  32. yield EagerResult('some_id', value, 'FAILURE')
  33. class TSRNoReport(TSR):
  34. def _failed_join_report(self):
  35. return iter([])
  36. @contextmanager
  37. def patch_unlock_retry(app):
  38. unlock = app.tasks['celery.chord_unlock']
  39. retry = Mock()
  40. retry.return_value = Retry()
  41. prev, unlock.retry = unlock.retry, retry
  42. try:
  43. yield unlock, retry
  44. finally:
  45. unlock.retry = prev
  46. class test_unlock_chord_task(ChordCase):
  47. def test_unlock_ready(self):
  48. class AlwaysReady(TSR):
  49. is_ready = True
  50. value = [2, 4, 8, 6]
  51. with self._chord_context(AlwaysReady) as (cb, retry, _):
  52. cb.type.apply_async.assert_called_with(
  53. ([2, 4, 8, 6], ), {}, task_id=cb.id,
  54. )
  55. # did not retry
  56. self.assertFalse(retry.call_count)
  57. def test_callback_fails(self):
  58. class AlwaysReady(TSR):
  59. is_ready = True
  60. value = [2, 4, 8, 6]
  61. def setup(callback):
  62. callback.apply_async.side_effect = IOError()
  63. with self._chord_context(AlwaysReady, setup) as (cb, retry, fail):
  64. self.assertTrue(fail.called)
  65. self.assertEqual(
  66. fail.call_args[0][0], cb.id,
  67. )
  68. self.assertIsInstance(
  69. fail.call_args[1]['exc'], ChordError,
  70. )
  71. def test_unlock_ready_failed(self):
  72. class Failed(TSR):
  73. is_ready = True
  74. value = [2, KeyError('foo'), 8, 6]
  75. with self._chord_context(Failed) as (cb, retry, fail_current):
  76. self.assertFalse(cb.type.apply_async.called)
  77. # did not retry
  78. self.assertFalse(retry.call_count)
  79. self.assertTrue(fail_current.called)
  80. self.assertEqual(
  81. fail_current.call_args[0][0], cb.id,
  82. )
  83. self.assertIsInstance(
  84. fail_current.call_args[1]['exc'], ChordError,
  85. )
  86. self.assertIn('some_id', str(fail_current.call_args[1]['exc']))
  87. def test_unlock_ready_failed_no_culprit(self):
  88. class Failed(TSRNoReport):
  89. is_ready = True
  90. value = [2, KeyError('foo'), 8, 6]
  91. with self._chord_context(Failed) as (cb, retry, fail_current):
  92. self.assertTrue(fail_current.called)
  93. self.assertEqual(
  94. fail_current.call_args[0][0], cb.id,
  95. )
  96. self.assertIsInstance(
  97. fail_current.call_args[1]['exc'], ChordError,
  98. )
  99. @contextmanager
  100. def _chord_context(self, ResultCls, setup=None, **kwargs):
  101. @self.app.task(shared=False)
  102. def callback(*args, **kwargs):
  103. pass
  104. self.app.finalize()
  105. pts, result.GroupResult = result.GroupResult, ResultCls
  106. callback.apply_async = Mock()
  107. callback_s = callback.s()
  108. callback_s.id = 'callback_id'
  109. fail_current = self.app.backend.fail_from_current_stack = Mock()
  110. try:
  111. with patch_unlock_retry(self.app) as (unlock, retry):
  112. subtask, canvas.maybe_signature = (
  113. canvas.maybe_signature, passthru,
  114. )
  115. if setup:
  116. setup(callback)
  117. try:
  118. assert self.app.tasks['celery.chord_unlock'] is unlock
  119. try:
  120. unlock(
  121. 'group_id', callback_s,
  122. result=[
  123. self.app.AsyncResult(r) for r in ['1', 2, 3]
  124. ],
  125. GroupResult=ResultCls, **kwargs
  126. )
  127. except Retry:
  128. pass
  129. finally:
  130. canvas.maybe_signature = subtask
  131. yield callback_s, retry, fail_current
  132. finally:
  133. result.GroupResult = pts
  134. def test_when_not_ready(self):
  135. class NeverReady(TSR):
  136. is_ready = False
  137. with self._chord_context(NeverReady, interval=10, max_retries=30) \
  138. as (cb, retry, _):
  139. self.assertFalse(cb.type.apply_async.called)
  140. # did retry
  141. retry.assert_called_with(countdown=10, max_retries=30)
  142. def test_is_in_registry(self):
  143. self.assertIn('celery.chord_unlock', self.app.tasks)
  144. class test_chord(ChordCase):
  145. def test_eager(self):
  146. from celery import chord
  147. @self.app.task(shared=False)
  148. def addX(x, y):
  149. return x + y
  150. @self.app.task(shared=False)
  151. def sumX(n):
  152. return sum(n)
  153. self.app.conf.CELERY_ALWAYS_EAGER = True
  154. x = chord(addX.s(i, i) for i in range(10))
  155. body = sumX.s()
  156. result = x(body)
  157. self.assertEqual(result.get(), sum(i + i for i in range(10)))
  158. def test_apply(self):
  159. self.app.conf.CELERY_ALWAYS_EAGER = False
  160. from celery import chord
  161. m = Mock()
  162. m.app.conf.CELERY_ALWAYS_EAGER = False
  163. m.AsyncResult = AsyncResult
  164. prev, chord._type = chord._type, m
  165. try:
  166. x = chord(self.add.s(i, i) for i in range(10))
  167. body = self.add.s(2)
  168. result = x(body)
  169. self.assertTrue(result.id)
  170. # does not modify original subtask
  171. with self.assertRaises(KeyError):
  172. body.options['task_id']
  173. self.assertTrue(chord._type.called)
  174. finally:
  175. chord._type = prev
  176. class test_Chord_task(ChordCase):
  177. def test_run(self):
  178. self.app.backend = Mock()
  179. self.app.backend.cleanup = Mock()
  180. self.app.backend.cleanup.__name__ = 'cleanup'
  181. Chord = self.app.tasks['celery.chord']
  182. body = dict()
  183. Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
  184. Chord([self.add.subtask((j, j)) for j in range(5)], body)
  185. self.assertEqual(self.app.backend.apply_chord.call_count, 2)