123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- import pytest
- from contextlib import contextmanager
- from case import Mock
- from celery import group, uuid
- from celery import canvas
- from celery import result
- from celery.exceptions import ChordError, Retry
- from celery.result import AsyncResult, GroupResult, EagerResult
def passthru(x):
    """Return ``x`` unchanged.

    Used by ``test_unlock_chord_task._chord_context`` as a stand-in for
    ``canvas.maybe_signature``.
    """
    unchanged = x
    return unchanged
class ChordCase:
    """Base class for the chord tests; provides a simple ``add`` task."""

    def setup(self):
        """Register an unshared ``add`` task on ``self.app`` as ``self.add``."""
        # NOTE(review): ``self.app`` is not assigned anywhere in this class;
        # presumably it is injected by the test harness -- confirm.
        as_task = self.app.task(shared=False)

        @as_task
        def add(x, y):
            return x + y

        self.add = add
class TSR(GroupResult):
    """``GroupResult`` stand-in whose readiness and values are class attributes.

    Subclasses override ``is_ready`` and ``value`` to simulate a particular
    group state.
    """

    # What ``ready()`` reports.
    is_ready = True
    # What ``join()`` returns; subclasses set this to a list of results in
    # which ``Exception`` instances stand for failed members.
    value = None

    def ready(self):
        """Return the canned ``is_ready`` flag."""
        return self.is_ready

    def join(self, propagate=True, **kwargs):
        """Return ``self.value``, first raising its first exception if asked.

        When ``propagate`` is true, the first ``Exception`` instance found in
        ``self.value`` is raised. Extra keyword arguments are accepted and
        ignored.
        """
        if propagate:
            first_error = next(
                (item for item in self.value if isinstance(item, Exception)),
                None,
            )
            if first_error is not None:
                raise first_error
        return self.value

    # The native join is the same canned implementation.
    join_native = join

    def _failed_join_report(self):
        """Yield a ``FAILURE`` ``EagerResult`` for each exception in ``self.value``."""
        errors = (item for item in self.value if isinstance(item, Exception))
        for error in errors:
            yield EagerResult('some_id', error, 'FAILURE')
class TSRNoReport(TSR):
    """``TSR`` variant whose failure report is always empty."""

    def _failed_join_report(self):
        """Return an empty iterator: no failed result is ever reported."""
        return iter(())
@contextmanager
def patch_unlock_retry(app):
    """Temporarily replace ``retry`` on the app's ``celery.chord_unlock`` task.

    The replacement is a ``Mock`` whose return value is a ``Retry`` instance.
    Yields ``(unlock_task, mock_retry)`` and puts the previous ``retry``
    attribute back on exit, even if the managed block raises.
    """
    unlock_task = app.tasks['celery.chord_unlock']

    fake_retry = Mock()
    fake_retry.return_value = Retry()

    original_retry = unlock_task.retry
    unlock_task.retry = fake_retry
    try:
        yield unlock_task, fake_retry
    finally:
        unlock_task.retry = original_retry
class test_unlock_chord_task(ChordCase):
    """Tests for the ``celery.chord_unlock`` task.

    Most tests run the task through ``_chord_context`` with a canned
    ``TSR`` subclass and then assert on the mocks it yields.
    """

    def test_unlock_ready(self):
        """A ready group with plain values calls the callback and never retries."""

        class AlwaysReady(TSR):
            is_ready = True
            value = [2, 4, 8, 6]

        with self._chord_context(AlwaysReady) as (cb, retry, _):
            cb.type.apply_async.assert_called_with(
                ([2, 4, 8, 6],), {}, task_id=cb.id,
            )
            # didn't retry
            assert not retry.call_count

    def test_deps_ready_fails(self):
        """A ``KeyError`` raised by ``ready()`` propagates out of the task."""
        # Named to avoid shadowing the module-level ``GroupResult`` import.
        group_result_cls = Mock(name='GroupResult')
        group_result_cls.return_value.ready.side_effect = KeyError('foo')
        unlock_chord = self.app.tasks['celery.chord_unlock']

        with pytest.raises(KeyError):
            unlock_chord('groupid', Mock(), result=[Mock()],
                         GroupResult=group_result_cls,
                         result_from_tuple=Mock())

    def test_callback_fails(self):
        """A callback whose ``apply_async`` raises is reported as ``ChordError``."""

        class AlwaysReady(TSR):
            is_ready = True
            value = [2, 4, 8, 6]

        def make_apply_async_fail(callback):
            callback.apply_async.side_effect = IOError()

        with self._chord_context(AlwaysReady, make_apply_async_fail) \
                as (cb, retry, fail):
            fail.assert_called()
            assert fail.call_args[0][0] == cb.id
            assert isinstance(fail.call_args[1]['exc'], ChordError)

    def test_unlock_ready_failed(self):
        """A failed member skips the callback and reports a ``ChordError``."""

        class Failed(TSR):
            is_ready = True
            value = [2, KeyError('foo'), 8, 6]

        with self._chord_context(Failed) as (cb, retry, fail_current):
            cb.type.apply_async.assert_not_called()
            # didn't retry
            assert not retry.call_count
            fail_current.assert_called()
            assert fail_current.call_args[0][0] == cb.id
            assert isinstance(fail_current.call_args[1]['exc'], ChordError)
            # The id of the failed result (from ``TSR._failed_join_report``)
            # appears in the error message.
            assert 'some_id' in str(fail_current.call_args[1]['exc'])

    def test_unlock_ready_failed_no_culprit(self):
        """A ``ChordError`` is still reported when the failure report is empty."""

        class Failed(TSRNoReport):
            is_ready = True
            value = [2, KeyError('foo'), 8, 6]

        with self._chord_context(Failed) as (cb, retry, fail_current):
            fail_current.assert_called()
            assert fail_current.call_args[0][0] == cb.id
            assert isinstance(fail_current.call_args[1]['exc'], ChordError)

    @contextmanager
    def _chord_context(self, ResultCls, setup=None, **kwargs):
        """Run ``celery.chord_unlock`` once in a patched environment.

        The task is invoked *before* control is handed to the ``with`` body,
        so the body only needs to assert on the yielded objects.

        Args:
            ResultCls: class passed to the task as ``GroupResult`` and
                temporarily installed as ``result.GroupResult``.
            setup: optional callable invoked with the callback task before
                the unlock task runs.
            **kwargs: forwarded to the unlock task call.

        Yields:
            ``(callback_signature, retry_mock, fail_from_current_stack_mock)``.

        Every monkeypatch is applied inside the ``try`` that undoes it, so
        each one is restored even when preparation or ``setup`` raises.
        """
        @self.app.task(shared=False)
        def callback(*args, **kwargs):
            pass
        self.app.finalize()

        callback.apply_async = Mock()
        callback_s = callback.s()
        callback_s.id = 'callback_id'

        backend = self.app.backend
        fail_current = Mock()
        prev_fail = backend.fail_from_current_stack
        prev_group_result = result.GroupResult
        try:
            backend.fail_from_current_stack = fail_current
            result.GroupResult = ResultCls
            with patch_unlock_retry(self.app) as (unlock, retry):
                prev_maybe_signature = canvas.maybe_signature
                try:
                    # Hand the callback signature through untouched.
                    canvas.maybe_signature = passthru
                    if setup:
                        setup(callback)
                    assert self.app.tasks['celery.chord_unlock'] is unlock
                    try:
                        unlock(
                            'group_id', callback_s,
                            result=[
                                self.app.AsyncResult(r) for r in ['1', 2, 3]
                            ],
                            GroupResult=ResultCls, **kwargs
                        )
                    except Retry:
                        # The patched ``retry`` returns a ``Retry``; callers
                        # assert on the mock instead.
                        pass
                finally:
                    canvas.maybe_signature = prev_maybe_signature
                yield callback_s, retry, fail_current
        finally:
            result.GroupResult = prev_group_result
            backend.fail_from_current_stack = prev_fail

    def test_when_not_ready(self):
        """An unready group retries with the given interval and max_retries."""

        class NeverReady(TSR):
            is_ready = False

        with self._chord_context(NeverReady, interval=10, max_retries=30) \
                as (cb, retry, _):
            cb.type.apply_async.assert_not_called()
            # did retry
            retry.assert_called_with(countdown=10, max_retries=30)

    def test_is_in_registry(self):
        """The unlock task is registered under ``celery.chord_unlock``."""
        assert 'celery.chord_unlock' in self.app.tasks
class test_chord(ChordCase):
    """Tests for calling a ``chord`` signature eagerly and non-eagerly."""

    def test_eager(self):
        """With ``task_always_eager`` on, the chord result is the body's sum."""
        from celery import chord

        @self.app.task(shared=False)
        def addX(x, y):
            return x + y

        @self.app.task(shared=False)
        def sumX(n):
            return sum(n)

        self.app.conf.task_always_eager = True
        header = chord(addX.s(i, i) for i in range(10))
        body = sumX.s()
        outcome = header(body)
        assert outcome.get() == sum(i + i for i in range(10))

    def test_apply(self):
        """A non-eager call goes through ``chord.run`` and leaves ``body`` intact."""
        self.app.conf.task_always_eager = False
        from celery import chord

        fake_run = Mock()
        fake_run.app.conf.task_always_eager = False
        fake_run.AsyncResult = AsyncResult

        original_run = chord.run
        chord.run = fake_run
        try:
            header = chord(self.add.s(i, i) for i in range(10))
            body = self.add.s(2)
            outcome = header(body)
            assert outcome.id
            # does not modify original signature
            with pytest.raises(KeyError):
                body.options['task_id']
            chord.run.assert_called()
        finally:
            chord.run = original_run
class test_add_to_chord:
    """Tests for ``Task.add_to_chord`` in lazy and non-lazy mode."""

    def setup(self):
        """Register ``add`` and a bound ``adds`` task that calls ``add_to_chord``."""
        as_task = self.app.task(shared=False)

        @as_task
        def add(x, y):
            return x + y

        self.add = add

        @self.app.task(shared=False, bind=True)
        def adds(self, sig, lazy=False):
            return self.add_to_chord(sig, lazy)

        self.adds = adds

    def test_add_to_chord(self):
        """Check the error outside a chord, then the lazy and non-lazy paths."""
        self.app.backend = Mock(name='backend')

        lazy_sig = self.add.s(2, 2)
        lazy_sig.delay = Mock(name='sig.delay')
        request = self.adds.request
        request.group = uuid()
        request.id = uuid()

        with pytest.raises(ValueError):
            # task not part of chord
            self.adds.run(lazy_sig)
        request.chord = self.add.s()

        # Lazy: the signature itself comes back and is not sent.
        lazy_outcome = self.adds.run(lazy_sig, True)
        assert lazy_outcome == lazy_sig
        assert lazy_sig.options['task_id']
        assert lazy_sig.options['group_id'] == self.adds.request.group
        assert lazy_sig.options['chord'] == self.adds.request.chord
        lazy_sig.delay.assert_not_called()
        self.app.backend.add_to_chord.assert_called_with(
            self.adds.request.group, lazy_sig.freeze(),
        )

        # Non-lazy: the signature is sent via ``delay()`` and that result
        # is returned.
        self.app.backend.reset_mock()
        sent_sig = self.add.s(4, 4)
        sent_sig.delay = Mock(name='sig2.delay')
        sent_outcome = self.adds.run(sent_sig)
        assert sent_outcome == sent_sig.delay.return_value
        assert sent_sig.options['task_id']
        assert sent_sig.options['group_id'] == self.adds.request.group
        assert sent_sig.options['chord'] == self.adds.request.chord
        sent_sig.delay.assert_called_with()
        self.app.backend.add_to_chord.assert_called_with(
            self.adds.request.group, sent_sig.freeze(),
        )
class test_Chord_task(ChordCase):
    """Tests for the built-in ``celery.chord`` task."""

    def test_run(self):
        """A ``group`` header and a list header each reach ``backend.apply_chord``."""
        self.app.backend = Mock()
        self.app.backend.cleanup = Mock()
        self.app.backend.cleanup.__name__ = 'cleanup'
        chord_task = self.app.tasks['celery.chord']

        body = self.add.signature()

        # Header given as a ``group`` built from a generator.
        group_header = group(self.add.signature((i, i)) for i in range(5))
        chord_task(group_header, body)

        # Header given as a plain list of signatures.
        list_header = [self.add.signature((j, j)) for j in range(5)]
        chord_task(list_header, body)

        assert self.app.backend.apply_chord.call_count == 2
|