from __future__ import absolute_import, unicode_literals

import pytest
from case import Mock, patch
from kombu.exceptions import EncodeError

from celery import group, signals, states, uuid
from celery.app.task import Context
from celery.app.trace import (TraceInfo, _fast_trace_task, _trace_task_ret,
                              build_tracer, get_log_policy, get_task_name,
                              log_policy_expected, log_policy_ignore,
                              log_policy_internal, log_policy_reject,
                              log_policy_unexpected,
                              reset_worker_optimizations,
                              setup_worker_optimizations, trace_task)
from celery.exceptions import Ignore, Reject, Retry


def trace(app, task, args=(), kwargs={},
          propagate=False, eager=True, request=None, **opts):
    t = build_tracer(task.name, task,
                     eager=eager, propagate=propagate, app=app, **opts)
    ret = t('id-1', args, kwargs, request)
    return ret.retval, ret.info


class TraceCase:

    def setup(self):
        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        self.add = add

        @self.app.task(shared=False, ignore_result=True)
        def add_cast(x, y):
            return x + y
        self.add_cast = add_cast

        @self.app.task(shared=False)
        def raises(exc):
            raise exc
        self.raises = raises

    def trace(self, *args, **kwargs):
        return trace(self.app, *args, **kwargs)


class test_trace(TraceCase):

    def test_trace_successful(self):
        retval, info = self.trace(self.add, (2, 2), {})
        assert info is None
        assert retval == 4

    def test_trace_on_success(self):

        @self.app.task(shared=False, on_success=Mock())
        def add_with_success(x, y):
            return x + y

        self.trace(add_with_success, (2, 2), {})
        add_with_success.on_success.assert_called()

    def test_get_log_policy(self):
        einfo = Mock(name='einfo')
        einfo.internal = False
        assert get_log_policy(self.add, einfo, Reject()) is log_policy_reject
        assert get_log_policy(self.add, einfo, Ignore()) is log_policy_ignore

        self.add.throws = (TypeError,)
        assert (get_log_policy(self.add, einfo, KeyError()) is
                log_policy_unexpected)
        assert (get_log_policy(self.add, einfo, TypeError()) is
                log_policy_expected)

        einfo2 = Mock(name='einfo2')
        einfo2.internal = True
        assert (get_log_policy(self.add, einfo2, KeyError()) is
                log_policy_internal)

    def test_get_task_name(self):
        assert get_task_name(Context({}), 'default') == 'default'
        assert get_task_name(Context({'shadow': None}), 'default') == 'default'
        assert get_task_name(Context({'shadow': ''}), 'default') == 'default'
        assert get_task_name(Context({'shadow': 'test'}), 'default') == 'test'

    def test_trace_after_return(self):

        @self.app.task(shared=False, after_return=Mock())
        def add_with_after_return(x, y):
            return x + y

        self.trace(add_with_after_return, (2, 2), {})
        add_with_after_return.after_return.assert_called()

    def test_with_prerun_receivers(self):
        on_prerun = Mock()
        signals.task_prerun.connect(on_prerun)
        try:
            self.trace(self.add, (2, 2), {})
            on_prerun.assert_called()
        finally:
            signals.task_prerun.receivers[:] = []

    def test_with_postrun_receivers(self):
        on_postrun = Mock()
        signals.task_postrun.connect(on_postrun)
        try:
            self.trace(self.add, (2, 2), {})
            on_postrun.assert_called()
        finally:
            signals.task_postrun.receivers[:] = []

    def test_with_success_receivers(self):
        on_success = Mock()
        signals.task_success.connect(on_success)
        try:
            self.trace(self.add, (2, 2), {})
            on_success.assert_called()
        finally:
            signals.task_success.receivers[:] = []

    def test_when_chord_part(self):

        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        add.backend = Mock()

        request = {'chord': uuid()}
        self.trace(add, (2, 2), {}, request=request)
        add.backend.mark_as_done.assert_called()
        args, kwargs = add.backend.mark_as_done.call_args
        assert args[0] == 'id-1'
        assert args[1] == 4
        assert args[2].chord == request['chord']
        assert not args[3]

    def test_when_backend_cleanup_raises(self):

        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        add.backend = Mock(name='backend')
        add.backend.process_cleanup.side_effect = KeyError()
        self.trace(add, (2, 2), {}, eager=False)
        add.backend.process_cleanup.assert_called_with()
        add.backend.process_cleanup.side_effect = MemoryError()
        with pytest.raises(MemoryError):
            self.trace(add, (2, 2), {}, eager=False)

    def test_when_Ignore(self):

        @self.app.task(shared=False)
        def ignored():
            raise Ignore()

        retval, info = self.trace(ignored, (), {})
        assert info.state == states.IGNORED

    def test_when_Reject(self):

        @self.app.task(shared=False)
        def rejecting():
            raise Reject()

        retval, info = self.trace(rejecting, (), {})
        assert info.state == states.REJECTED

    def test_backend_cleanup_raises(self):
        self.add.backend.process_cleanup = Mock()
        self.add.backend.process_cleanup.side_effect = RuntimeError()
        self.trace(self.add, (2, 2), {})

    @patch('celery.canvas.maybe_signature')
    def test_callbacks__scalar(self, maybe_signature):
        sig = Mock(name='sig')
        request = {'callbacks': [sig], 'root_id': 'root'}
        maybe_signature.return_value = sig
        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
        sig.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )

    @patch('celery.canvas.maybe_signature')
    def test_chain_proto2(self, maybe_signature):
        sig = Mock(name='sig')
        sig2 = Mock(name='sig2')
        request = {'chain': [sig2, sig], 'root_id': 'root'}
        maybe_signature.return_value = sig
        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
        sig.apply_async.assert_called_with(
            (4, ), parent_id='id-1', root_id='root',
            chain=[sig2],
        )

    @patch('celery.canvas.maybe_signature')
    def test_callbacks__EncodeError(self, maybe_signature):
        sig = Mock(name='sig')
        request = {'callbacks': [sig], 'root_id': 'root'}
        maybe_signature.return_value = sig
        sig.apply_async.side_effect = EncodeError()
        retval, einfo = self.trace(self.add, (2, 2), {}, request=request)
        assert einfo.state == states.FAILURE

    @patch('celery.canvas.maybe_signature')
    @patch('celery.app.trace.group.apply_async')
    def test_callbacks__sigs(self, group_, maybe_signature):
        sig1 = Mock(name='sig')
        sig2 = Mock(name='sig2')
        sig3 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
        sig3.apply_async = Mock(name='gapply')
        request = {'callbacks': [sig1, sig3, sig2], 'root_id': 'root'}

        def passt(s, *args, **kwargs):
            return s
        maybe_signature.side_effect = passt
        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
        group_.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )
        sig3.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )

    @patch('celery.canvas.maybe_signature')
    @patch('celery.app.trace.group.apply_async')
    def test_callbacks__only_groups(self, group_, maybe_signature):
        sig1 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
        sig2 = group([Mock(name='g3'), Mock(name='g4')], app=self.app)
        sig1.apply_async = Mock(name='gapply')
        sig2.apply_async = Mock(name='gapply')
        request = {'callbacks': [sig1, sig2], 'root_id': 'root'}

        def passt(s, *args, **kwargs):
            return s
        maybe_signature.side_effect = passt
        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
        sig1.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )
        sig2.apply_async.assert_called_with(
            (4,), parent_id='id-1', root_id='root',
        )

    def test_trace_SystemExit(self):
        with pytest.raises(SystemExit):
            self.trace(self.raises, (SystemExit(),), {})

    def test_trace_Retry(self):
        exc = Retry('foo', 'bar')
        _, info = self.trace(self.raises, (exc,), {})
        assert info.state == states.RETRY
        assert info.retval is exc

    def test_trace_exception(self):
        exc = KeyError('foo')
        _, info = self.trace(self.raises, (exc,), {})
        assert info.state == states.FAILURE
        assert info.retval is exc

    def test_trace_task_ret__no_content_type(self):
        _trace_task_ret(
            self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
            app=self.app,
        )

    def test_fast_trace_task__no_content_type(self):
        self.app.tasks[self.add.name].__trace__ = build_tracer(
            self.add.name, self.add, app=self.app,
        )
        _fast_trace_task(
            self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
            app=self.app, _loc=[self.app.tasks, {}, 'hostname']
        )

    def test_trace_exception_propagate(self):
        with pytest.raises(KeyError):
            self.trace(self.raises, (KeyError('foo'),), {}, propagate=True)

    @patch('celery.app.trace.build_tracer')
    @patch('celery.app.trace.report_internal_error')
    def test_outside_body_error(self, report_internal_error, build_tracer):
        tracer = Mock()
        tracer.side_effect = KeyError('foo')
        build_tracer.return_value = tracer

        @self.app.task(shared=False)
        def xtask():
            pass

        trace_task(xtask, 'uuid', (), {})
        assert report_internal_error.call_count
        assert xtask.__trace__ is tracer


class test_TraceInfo(TraceCase):

    class TI(TraceInfo):
        __slots__ = TraceInfo.__slots__ + ('__dict__',)

    def test_handle_error_state(self):
        x = self.TI(states.FAILURE)
        x.handle_failure = Mock()
        x.handle_error_state(self.add_cast, self.add_cast.request)
        x.handle_failure.assert_called_with(
            self.add_cast, self.add_cast.request,
            store_errors=self.add_cast.store_errors_even_if_ignored,
            call_errbacks=True,
        )

    @patch('celery.app.trace.ExceptionInfo')
    def test_handle_reject(self, ExceptionInfo):
        x = self.TI(states.FAILURE)
        x._log_error = Mock(name='log_error')
        req = Mock(name='req')
        x.handle_reject(self.add, req)
        x._log_error.assert_called_with(self.add, req, ExceptionInfo())


class test_stackprotection:

    def test_stackprotection(self):
        setup_worker_optimizations(self.app)
        try:
            @self.app.task(shared=False, bind=True)
            def foo(self, i):
                if i:
                    return foo(0)
                return self.request

            assert foo(1).called_directly
        finally:
            reset_worker_optimizations()