"""Unit tests for :mod:`celery.app.trace`.

Covers tracer construction, log-policy selection, task callbacks/chains,
task signals and :class:`~celery.app.trace.TraceInfo` error handling.
"""
from __future__ import absolute_import, unicode_literals

import pytest
from case import Mock, patch
from kombu.exceptions import EncodeError

from celery import group, signals, states, uuid
from celery.app.task import Context
from celery.app.trace import (TraceInfo, _fast_trace_task, _trace_task_ret,
                              build_tracer, get_log_policy, get_task_name,
                              log_policy_expected, log_policy_ignore,
                              log_policy_internal, log_policy_reject,
                              log_policy_unexpected,
                              reset_worker_optimizations,
                              setup_worker_optimizations, trace_task)
from celery.exceptions import Ignore, Reject, Retry


def trace(app, task, args=(), kwargs=None, propagate=False, eager=True,
          request=None, **opts):
    """Build a tracer for *task* and run it once as task id ``'id-1'``.

    :param app: Celery app forwarded to :func:`build_tracer`.
    :param task: the task whose tracer is built (its ``name`` is used).
    :param args: positional arguments passed to the task.
    :param kwargs: keyword arguments passed to the task; ``None`` (the
        default) means no keyword arguments.
    :param propagate: forwarded to :func:`build_tracer`.
    :param eager: forwarded to :func:`build_tracer`.
    :param request: request passed as the tracer's fourth argument.
    :param opts: extra keyword arguments forwarded to :func:`build_tracer`.
    :returns: ``(retval, info)`` read from the tracer's return value.
    """
    # Fresh dict per call: a ``{}`` default would be one object shared by
    # every call of this helper.
    if kwargs is None:
        kwargs = {}
    t = build_tracer(task.name, task,
                     eager=eager, propagate=propagate, app=app, **opts)
    ret = t('id-1', args, kwargs, request)
    return ret.retval, ret.info


class TraceCase:
    """Shared setup for trace tests: defines ``add``, ``add_cast``, ``raises``.

    ``self.app`` is not assigned in this module; it is expected to be
    injected by the test harness (presumably a conftest fixture).
    """

    def setup(self):
        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        self.add = add

        @self.app.task(shared=False, ignore_result=True)
        def add_cast(x, y):
            return x + y
        self.add_cast = add_cast

        @self.app.task(shared=False)
        def raises(exc):
            raise exc
        self.raises = raises

    def trace(self, *args, **kwargs):
        """Call the module-level :func:`trace` bound to ``self.app``."""
        return trace(self.app, *args, **kwargs)


class test_trace(TraceCase):

    def test_trace_successful(self):
        """A successful trace returns the task result and no info."""
        retval, info = self.trace(self.add, (2, 2), {})
        assert info is None
        assert retval == 4

    def test_trace_on_success(self):
        """The task's ``on_success`` handler is called."""
        @self.app.task(shared=False, on_success=Mock())
        def add_with_success(x, y):
            return x + y

        self.trace(add_with_success, (2, 2), {})
        add_with_success.on_success.assert_called()

    def test_get_log_policy(self):
        """The log policy depends on exception type, ``throws`` and
        ``einfo.internal``."""
        einfo = Mock(name='einfo')
        einfo.internal = False
        assert get_log_policy(self.add, einfo, Reject()) is log_policy_reject
        assert get_log_policy(self.add, einfo, Ignore()) is log_policy_ignore

        self.add.throws = (TypeError,)
        assert (get_log_policy(self.add, einfo, KeyError()) is
                log_policy_unexpected)
        assert (get_log_policy(self.add, einfo, TypeError()) is
                log_policy_expected)

        einfo2 = Mock(name='einfo2')
        einfo2.internal = True
        assert (get_log_policy(self.add, einfo2, KeyError()) is
                log_policy_internal)

    def test_get_task_name(self):
        """A truthy ``shadow`` in the request overrides the default name."""
        assert get_task_name(Context({}), 'default') == 'default'
        assert get_task_name(Context({'shadow': None}), 'default') == 'default'
        assert get_task_name(Context({'shadow': ''}), 'default') == 'default'
        assert get_task_name(Context({'shadow': 'test'}), 'default') == 'test'

    def test_trace_after_return(self):
        """The task's ``after_return`` handler is called."""
        @self.app.task(shared=False, after_return=Mock())
        def add_with_after_return(x, y):
            return x + y

        self.trace(add_with_after_return, (2, 2), {})
        add_with_after_return.after_return.assert_called()

    def test_with_prerun_receivers(self):
        """Receivers connected to ``task_prerun`` are called."""
        on_prerun = Mock()
        signals.task_prerun.connect(on_prerun)
        try:
            self.trace(self.add, (2, 2), {})
            on_prerun.assert_called()
        finally:
            # Remove only our receiver rather than clearing every receiver
            # connected to this process-global signal.
            signals.task_prerun.disconnect(on_prerun)

    def test_with_postrun_receivers(self):
        """Receivers connected to ``task_postrun`` are called."""
        on_postrun = Mock()
        signals.task_postrun.connect(on_postrun)
        try:
            self.trace(self.add, (2, 2), {})
            on_postrun.assert_called()
        finally:
            signals.task_postrun.disconnect(on_postrun)

    def test_with_success_receivers(self):
        """Receivers connected to ``task_success`` are called."""
        on_success = Mock()
        signals.task_success.connect(on_success)
        try:
            self.trace(self.add, (2, 2), {})
            on_success.assert_called()
        finally:
            signals.task_success.disconnect(on_success)

    def test_when_chord_part(self):
        """With a ``chord`` in the request, ``mark_as_done`` gets the task id,
        the result and the request."""
        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        add.backend = Mock()

        request = {'chord': uuid()}
        self.trace(add, (2, 2), {}, request=request)
        add.backend.mark_as_done.assert_called()
        args, _ = add.backend.mark_as_done.call_args
        assert args[0] == 'id-1'
        assert args[1] == 4
        assert args[2].chord == request['chord']
        assert not args[3]

    def test_when_backend_cleanup_raises(self):
        """A KeyError from ``process_cleanup`` is not propagated; a
        MemoryError is."""
        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        add.backend = Mock(name='backend')
        add.backend.process_cleanup.side_effect = KeyError()
        self.trace(add, (2, 2), {}, eager=False)
        add.backend.process_cleanup.assert_called_with()
        add.backend.process_cleanup.side_effect = MemoryError()
        with pytest.raises(MemoryError):
            self.trace(add, (2, 2), {}, eager=False)

    def test_when_Ignore(self):
        """Raising :exc:`Ignore` yields the IGNORED state."""
        @self.app.task(shared=False)
        def ignored():
            raise Ignore()

        _, info = self.trace(ignored, (), {})
        assert info.state == states.IGNORED

    def test_when_Reject(self):
        """Raising :exc:`Reject` yields the REJECTED state."""
        @self.app.task(shared=False)
        def rejecting():
            raise Reject()

        _, info = self.trace(rejecting, (), {})
        assert info.state == states.REJECTED

    def test_backend_cleanup_raises(self):
        """Tracing completes with ``process_cleanup`` set to raise."""
        self.add.backend.process_cleanup = Mock()
        self.add.backend.process_cleanup.side_effect = RuntimeError()
        self.trace(self.add, (2, 2), {})

    @patch('celery.canvas.maybe_signature')
    def test_callbacks__scalar(self, maybe_signature):
        """A single callback is applied with the result and parent/root ids."""
        sig = Mock(name='sig')
        request = {'callbacks': [sig], 'root_id': 'root'}
        maybe_signature.return_value = sig
        self.trace(self.add, (2, 2), {}, request=request)
        sig.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )

    @patch('celery.canvas.maybe_signature')
    def test_chain_proto2(self, maybe_signature):
        """The next chain step is applied with the rest of the chain."""
        sig = Mock(name='sig')
        sig2 = Mock(name='sig2')
        request = {'chain': [sig2, sig], 'root_id': 'root'}
        maybe_signature.return_value = sig
        self.trace(self.add, (2, 2), {}, request=request)
        sig.apply_async.assert_called_with(
            (4, ), parent_id='id-1', root_id='root',
            chain=[sig2],
        )

    @patch('celery.canvas.maybe_signature')
    def test_callbacks__EncodeError(self, maybe_signature):
        """An EncodeError while applying a callback yields FAILURE."""
        sig = Mock(name='sig')
        request = {'callbacks': [sig], 'root_id': 'root'}
        maybe_signature.return_value = sig
        sig.apply_async.side_effect = EncodeError()
        _, einfo = self.trace(self.add, (2, 2), {}, request=request)
        assert einfo.state == states.FAILURE

    @patch('celery.canvas.maybe_signature')
    @patch('celery.app.trace.group.apply_async')
    def test_callbacks__sigs(self, group_, maybe_signature):
        """With mixed callbacks, ``group.apply_async`` and the group
        callback's own ``apply_async`` are both called with the result."""
        sig1 = Mock(name='sig')
        sig2 = Mock(name='sig2')
        sig3 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
        sig3.apply_async = Mock(name='gapply')
        request = {'callbacks': [sig1, sig3, sig2], 'root_id': 'root'}

        def passt(s, *args, **kwargs):
            return s
        maybe_signature.side_effect = passt
        self.trace(self.add, (2, 2), {}, request=request)
        group_.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )
        sig3.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )

    @patch('celery.canvas.maybe_signature')
    @patch('celery.app.trace.group.apply_async')
    def test_callbacks__only_groups(self, group_, maybe_signature):
        """When every callback is a group, each one is applied directly."""
        sig1 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
        sig2 = group([Mock(name='g3'), Mock(name='g4')], app=self.app)
        sig1.apply_async = Mock(name='gapply')
        sig2.apply_async = Mock(name='gapply')
        request = {'callbacks': [sig1, sig2], 'root_id': 'root'}

        def passt(s, *args, **kwargs):
            return s
        maybe_signature.side_effect = passt
        self.trace(self.add, (2, 2), {}, request=request)
        sig1.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )
        sig2.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )

    def test_trace_SystemExit(self):
        """SystemExit raised by the task propagates out of the trace."""
        with pytest.raises(SystemExit):
            self.trace(self.raises, (SystemExit(),), {})

    def test_trace_Retry(self):
        """Retry yields RETRY with the exception as ``retval``."""
        exc = Retry('foo', 'bar')
        _, info = self.trace(self.raises, (exc,), {})
        assert info.state == states.RETRY
        assert info.retval is exc

    def test_trace_exception(self):
        """Any other exception yields FAILURE with it as ``retval``."""
        exc = KeyError('foo')
        _, info = self.trace(self.raises, (exc,), {})
        assert info.state == states.FAILURE
        assert info.retval is exc

    def test_trace_task_ret__no_content_type(self):
        """Smoke test: ``_trace_task_ret`` runs with ``None`` content args."""
        _trace_task_ret(
            self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
            app=self.app,
        )

    def test_fast_trace_task__no_content_type(self):
        """Smoke test: ``_fast_trace_task`` runs with ``None`` content args."""
        self.app.tasks[self.add.name].__trace__ = build_tracer(
            self.add.name, self.add, app=self.app,
        )
        _fast_trace_task(
            self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
            app=self.app, _loc=[self.app.tasks, {}, 'hostname']
        )

    def test_trace_exception_propagate(self):
        """With ``propagate=True`` the task's exception is re-raised."""
        with pytest.raises(KeyError):
            self.trace(self.raises, (KeyError('foo'),), {}, propagate=True)

    @patch('celery.app.trace.build_tracer')
    @patch('celery.app.trace.report_internal_error')
    def test_outside_body_error(self, report_internal_error, build_tracer):
        """If the tracer itself raises, ``trace_task`` reports an internal
        error, and the built tracer is cached as ``__trace__``."""
        tracer = Mock()
        tracer.side_effect = KeyError('foo')
        build_tracer.return_value = tracer

        @self.app.task(shared=False)
        def xtask():
            pass

        trace_task(xtask, 'uuid', (), {})
        assert report_internal_error.call_count
        assert xtask.__trace__ is tracer


class test_TraceInfo(TraceCase):

    class TI(TraceInfo):
        # TraceInfo uses __slots__; adding '__dict__' lets these tests
        # assign mocks as instance attributes.
        __slots__ = TraceInfo.__slots__ + ('__dict__',)

    def test_handle_error_state(self):
        """FAILURE delegates to ``handle_failure`` with the task's
        ``store_errors_even_if_ignored`` and ``call_errbacks=True``."""
        x = self.TI(states.FAILURE)
        x.handle_failure = Mock()
        x.handle_error_state(self.add_cast, self.add_cast.request)
        x.handle_failure.assert_called_with(
            self.add_cast,
            self.add_cast.request,
            store_errors=self.add_cast.store_errors_even_if_ignored,
            call_errbacks=True,
        )

    @patch('celery.app.trace.ExceptionInfo')
    def test_handle_reject(self, ExceptionInfo):
        """``handle_reject`` logs via ``_log_error`` with an ExceptionInfo."""
        x = self.TI(states.FAILURE)
        x._log_error = Mock(name='log_error')
        req = Mock(name='req')
        x.handle_reject(self.add, req)
        x._log_error.assert_called_with(self.add, req, ExceptionInfo())


class test_stackprotection:
    # ``self.app`` is expected to be injected by the test harness.

    def test_stackprotection(self):
        """With worker optimizations set up, a task called from inside
        itself still sees a ``called_directly`` request."""
        setup_worker_optimizations(self.app)
        try:
            @self.app.task(shared=False, bind=True)
            def foo(self, i):
                # ``self`` here is the bound task, not the test case.
                if i:
                    return foo(0)
                return self.request

            assert foo(1).called_directly
        finally:
            reset_worker_optimizations()