from __future__ import absolute_import, unicode_literals

import errno
import signal
import sys

import pytest
from case import Mock, call, patch, skip

from celery.apps.multi import (Cluster, MultiParser, NamespacedOptionParser,
                               Node, format_opt)


class test_functions:
    """Tests for the small helpers exposed by ``celery.apps.multi``."""

    def test_parse_ns_range(self):
        """``_parse_ns_range`` expands ``a-b`` ranges only when asked to."""
        m = MultiParser()
        # With range expansion enabled, every number in the range is listed.
        # (This must be a real ``==`` comparison: ``assert value, message``
        # would only check that the returned list is truthy.)
        assert m._parse_ns_range('1-3', True) == ['1', '2', '3']
        # With range expansion disabled, the namespace is kept verbatim.
        assert m._parse_ns_range('1-3', False) == ['1-3']
        # Comma-separated namespaces may mix ranges and single values.
        assert m._parse_ns_range('1-3,10,11,20', True) == [
            '1', '2', '3', '10', '11', '20',
        ]

    def test_format_opt(self):
        """``format_opt`` renders an option/value pair as one string."""
        # No value: the bare option is returned.
        assert format_opt('--foo', None) == '--foo'
        # Short option: value separated by a space.
        assert format_opt('-c', 1) == '-c 1'
        # Long option: value joined with ``=``.
        assert format_opt('--log', 'foo') == '--log=foo'


class test_NamespacedOptionParser:
    """Tests for :class:`NamespacedOptionParser` argument parsing."""

    def test_parse(self):
        """Options are split into global, namespaced, positional and
        passthrough parts."""
        # A short option with a ``:namespace`` suffix is stored under that
        # namespace, taking the following argument as its value.
        x = NamespacedOptionParser(['-c:1,3', '4'])
        x.parse()
        assert x.namespaces.get('1,3') == {'-c': '4'}

        x = NamespacedOptionParser(['-c:jerry,elaine', '5',
                                    '--loglevel:kramer=DEBUG',
                                    '--flag',
                                    '--logfile=foo', '-Q', 'bar', 'a', 'b',
                                    '--', '.disable_rate_limits=1'])
        x.parse()
        # Options without a namespace end up in ``options``; a long option
        # without ``=value`` maps to ``None``.
        assert x.options == {
            '--logfile': 'foo',
            '-Q': 'bar',
            '--flag': None,
        }
        # Positional arguments are collected in order.  (This must be a real
        # ``==`` comparison: ``assert value, message`` would only check that
        # the list is non-empty.)
        assert x.values == ['a', 'b']
        assert x.namespaces.get('jerry,elaine') == {'-c': '5'}
        assert x.namespaces.get('kramer') == {'--loglevel': 'DEBUG'}
        # Everything from ``--`` onwards is kept as one passthrough string.
        assert x.passthrough == '-- .disable_rate_limits=1'


def multi_args(p, *args, **kwargs):
    """Parse ``p`` with a freshly constructed :class:`MultiParser`.

    All extra positional and keyword arguments are forwarded unchanged to
    the ``MultiParser`` constructor; the result of its ``parse(p)`` call is
    returned as-is.
    """
    parser = MultiParser(*args, **kwargs)
    return parser.parse(p)


class test_multi_args:
    """Tests for turning parsed command-line options into worker nodes."""

    @patch('celery.apps.multi.gethostname')
    def test_parse(self, gethostname):
        """Check node names and argv produced for several option layouts.

        ``gethostname`` is patched so that node names which use the local
        hostname are predictable (``example.com``).
        """
        # Pin the hostname used for nodes that do not name a host themselves.
        gethostname.return_value = 'example.com'
        p = NamespacedOptionParser([
            '-c:jerry,elaine', '5',
            '--loglevel:kramer=DEBUG',
            '--flag',
            '--logfile=foo', '-Q', 'bar', 'jerry',
            'elaine', 'kramer',
            '--', '.disable_rate_limits=1',
        ])
        p.parse()
        it = multi_args(p, cmd='COMMAND', append='*AP*',
                        prefix='*P*', suffix='*S*')
        nodes = list(it)

        def assert_line_in(name, args):
            """Assert a node called ``name`` exists and that its argv
            contains every item of ``args`` (order is not checked)."""
            assert name in {n.name for n in nodes}
            argv = None
            for node in nodes:
                if node.name == name:
                    argv = node.argv
            assert argv
            for arg in args:
                assert arg in argv

        # jerry and elaine share the ``-c 5`` namespaced option ...
        assert_line_in(
            '*P*jerry@*S*',
            ['COMMAND', '-n *P*jerry@*S*', '-Q bar',
             '-c 5', '--flag', '--logfile=foo',
             '-- .disable_rate_limits=1', '*AP*'],
        )
        assert_line_in(
            '*P*elaine@*S*',
            ['COMMAND', '-n *P*elaine@*S*', '-Q bar',
             '-c 5', '--flag', '--logfile=foo',
             '-- .disable_rate_limits=1', '*AP*'],
        )
        # ... while kramer only receives its own ``--loglevel``.
        assert_line_in(
            '*P*kramer@*S*',
            ['COMMAND', '--loglevel=DEBUG', '-n *P*kramer@*S*',
             '-Q bar', '--flag', '--logfile=foo',
             '-- .disable_rate_limits=1', '*AP*'],
        )
        # The node's expander resolves ``%h`` to the full node name and
        # ``%n`` to the part before the ``@``.
        expand = nodes[0].expander
        assert expand('%h') == '*P*jerry@*S*'
        assert expand('%n') == '*P*jerry'
        # With an empty ``append`` the passthrough is the last argv item.
        nodes2 = list(multi_args(p, cmd='COMMAND', append='',
                                 prefix='*P*', suffix='*S*'))
        assert nodes2[0].argv[-1] == '-- .disable_rate_limits=1'

        # A numeric positional argument requests that many nodes;
        # ``-c:1`` addresses the first of them.
        p2 = NamespacedOptionParser(['10', '-c:1', '5'])
        p2.parse()
        nodes3 = list(multi_args(p2, cmd='COMMAND'))

        def _args(name, *args):
            """Return ``args`` followed by the trailing arguments every node
            is expected to carry: pidfile, logfile, executable and a final
            empty string."""
            return args + (
                '--pidfile={0}.pid'.format(name),
                '--logfile={0}%I.log'.format(name),
                '--executable={0}'.format(sys.executable),
                '',
            )

        assert len(nodes3) == 10
        assert nodes3[0].name == 'celery1@example.com'
        assert nodes3[0].argv == (
            'COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1')
        # The remaining nodes are numbered from 2 and carry no ``-c``.
        for i, worker in enumerate(nodes3[1:]):
            assert worker.name == 'celery%s@example.com' % (i + 2)
            node_i = 'celery%s' % (i + 2,)
            assert worker.argv == (
                'COMMAND',
                '-n %s@example.com' % (node_i,)) + _args(node_i)

        # A suffix of two literal double quotes yields an empty host part.
        nodes4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
        assert len(nodes4) == 10
        assert nodes4[0].name == 'celery1@'
        assert nodes4[0].argv == (
            'COMMAND', '-c 5', '-n celery1@') + _args('celery1')

        # A namespace may refer to a node by name (``-c:foo``).
        p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
        p3.parse()
        nodes5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
        assert nodes5[0].name == 'foo@'
        assert nodes5[0].argv == (
            'COMMAND', '-c 5', '-n foo@') + _args('foo')

        # ... or by index (``-Q:1`` is the first node).
        p4 = NamespacedOptionParser(['foo', '-Q:1', 'test'])
        p4.parse()
        nodes6 = list(multi_args(p4, cmd='COMMAND', suffix='""'))
        assert nodes6[0].name == 'foo@'
        assert nodes6[0].argv == (
            'COMMAND', '-Q test', '-n foo@') + _args('foo')

        # A node that already names a host keeps it; pidfile and logfile
        # are still derived from the part before the ``@``.
        p5 = NamespacedOptionParser(['foo@bar', '-Q:1', 'test'])
        p5.parse()
        nodes7 = list(multi_args(p5, cmd='COMMAND', suffix='""'))
        assert nodes7[0].name == 'foo@bar'
        assert nodes7[0].argv == (
            'COMMAND', '-Q test', '-n foo@bar') + _args('foo')

        # Index 0 does not address any node here and raises ``KeyError``
        # (indices appear to be 1-based, cf. ``-c:1`` / ``-Q:1`` above).
        p6 = NamespacedOptionParser(['foo@bar', '-Q:0', 'test'])
        p6.parse()
        with pytest.raises(KeyError):
            list(multi_args(p6))

    def test_optmerge(self):
        """Options set on the parser show up in ``optmerge`` for a
        namespace."""
        p = NamespacedOptionParser(['foo', 'test'])
        p.parse()
        p.options = {'x': 'y'}
        r = p.optmerge('foo')
        assert r['x'] == 'y'


class test_Node:
    """Tests for :class:`Node`: signalling, starting and exit handling."""

    def setup(self):
        """Create a node with a mocked expander and a fixed pid."""
        self.p = Mock(name='p')
        self.p.options = {
            '--executable': 'python',
            '--logfile': 'foo.log',
        }
        self.p.namespaces = {}
        self.node = Node('foo@bar.com', options={'-A': 'proj'})
        # Replace the expander so tests can assert on the templates it gets.
        self.expander = self.node.expander = Mock(name='expander')
        self.node.pid = 303

    def test_from_kwargs(self):
        """``from_kwargs`` turns keyword arguments into command-line
        options (single letters become short options, longer names become
        dashed long options)."""
        n = Node.from_kwargs(
            'foo@bar.com',
            max_tasks_per_child=30, A='foo', Q='q1,q2', O='fair',
        )
        # Compared sorted: only the set of arguments matters here.
        assert sorted(n.argv) == sorted([
            '-m celery worker --detach',
            '-A foo',
            '--executable={0}'.format(n.executable),
            '-O fair',
            '-n foo@bar.com',
            '--logfile=foo%I.log',
            '-Q q1,q2',
            '--max-tasks-per-child=30',
            '--pidfile=foo.pid',
            '',
        ])

    @patch('os.kill')
    def test_send(self, kill):
        """``send`` signals the node's pid and returns a truthy value."""
        assert self.node.send(9)
        kill.assert_called_with(self.node.pid, 9)

    @patch('os.kill')
    def test_send__ESRCH(self, kill):
        """``ESRCH`` from ``os.kill`` makes ``send`` return a falsy value
        instead of raising."""
        kill.side_effect = OSError()
        kill.side_effect.errno = errno.ESRCH
        assert not self.node.send(9)
        kill.assert_called_with(self.node.pid, 9)

    @patch('os.kill')
    def test_send__error(self, kill):
        """An ``OSError`` with another errno propagates out of ``send``."""
        kill.side_effect = OSError()
        kill.side_effect.errno = errno.ENOENT
        with pytest.raises(OSError):
            self.node.send(9)
        kill.assert_called_with(self.node.pid, 9)

    def test_alive(self):
        """``alive`` returns the result of sending signal 0."""
        self.node.send = Mock(name='send')
        assert self.node.alive() is self.node.send.return_value
        self.node.send.assert_called_with(0)

    def test_start(self):
        """``start`` delegates to ``_waitexec`` with the node's argv and
        executable, forwarding ``env`` and extra keyword arguments."""
        self.node._waitexec = Mock(name='_waitexec')
        self.node.start(env={'foo': 'bar'}, kw=2)
        self.node._waitexec.assert_called_with(
            self.node.argv, path=self.node.executable,
            env={'foo': 'bar'}, kw=2,
        )

    @patch('celery.apps.multi.Popen')
    def test_waitexec(self, Popen, argv=None):
        """``_waitexec`` spawns the process and hands its exit status to
        ``handle_process_exit``.

        ``argv`` defaults to ``['A', 'B']``.  The list is created per call
        instead of being a mutable default argument shared between calls.
        """
        if argv is None:
            argv = ['A', 'B']
        on_spawn = Mock(name='on_spawn')
        on_signalled = Mock(name='on_signalled')
        on_failure = Mock(name='on_failure')
        env = Mock(name='env')
        self.node.handle_process_exit = Mock(name='handle_process_exit')

        self.node._waitexec(
            argv,
            path='python',
            env=env,
            on_spawn=on_spawn,
            on_signalled=on_signalled,
            on_failure=on_failure,
        )

        Popen.assert_called_with(
            self.node.prepare_argv(argv, 'python'), env=env)
        # ``Popen().wait()`` is the mock's return value, i.e. the same
        # object ``_waitexec`` obtained from waiting on the process.
        self.node.handle_process_exit.assert_called_with(
            Popen().wait(),
            on_signalled=on_signalled,
            on_failure=on_failure,
        )

    def test_handle_process_exit(self):
        """A zero exit status is returned unchanged."""
        assert self.node.handle_process_exit(0) == 0

    def test_handle_process_exit__failure(self):
        """A positive exit status is returned and reported to
        ``on_failure``."""
        on_failure = Mock(name='on_failure')
        assert self.node.handle_process_exit(9, on_failure=on_failure) == 9
        on_failure.assert_called_with(self.node, 9)

    def test_handle_process_exit__signalled(self):
        """A negative exit status is reported to ``on_signalled`` and
        returned as its positive value."""
        on_signalled = Mock(name='on_signalled')
        assert self.node.handle_process_exit(
            -9, on_signalled=on_signalled) == 9
        on_signalled.assert_called_with(self.node, 9)

    def test_logfile(self):
        """``logfile`` is the expansion of the ``%n%I.log`` template."""
        assert self.node.logfile == self.expander.return_value
        self.expander.assert_called_with('%n%I.log')


class test_Cluster:
    """Tests for :class:`Cluster`: starting nodes, signalling and pids."""

    def setup(self):
        """Patch process-related names and build a three-node cluster whose
        callbacks are all mocks."""
        # ``self.patching`` is not defined in this file; presumably it is
        # supplied by the test harness and returns the installed mock.
        self.Popen = self.patching('celery.apps.multi.Popen')
        self.kill = self.patching('os.kill')
        self.gethostname = self.patching('celery.apps.multi.gethostname')
        self.gethostname.return_value = 'example.com'
        self.Pidfile = self.patching('celery.apps.multi.Pidfile')
        self.cluster = Cluster(
            [Node('foo@example.com'),
             Node('bar@example.com'),
             Node('baz@example.com')],
            on_stopping_preamble=Mock(name='on_stopping_preamble'),
            on_send_signal=Mock(name='on_send_signal'),
            on_still_waiting_for=Mock(name='on_still_waiting_for'),
            on_still_waiting_progress=Mock(name='on_still_waiting_progress'),
            on_still_waiting_end=Mock(name='on_still_waiting_end'),
            on_node_start=Mock(name='on_node_start'),
            on_node_restart=Mock(name='on_node_restart'),
            on_node_shutdown_ok=Mock(name='on_node_shutdown_ok'),
            on_node_status=Mock(name='on_node_status'),
            on_node_signal=Mock(name='on_node_signal'),
            on_node_signal_dead=Mock(name='on_node_signal_dead'),
            on_node_down=Mock(name='on_node_down'),
            on_child_spawn=Mock(name='on_child_spawn'),
            on_child_signalled=Mock(name='on_child_signalled'),
            on_child_failure=Mock(name='on_child_failure'),
        )

    def test_len(self):
        """``len(cluster)`` is the number of nodes."""
        assert len(self.cluster) == 3

    def test_getitem(self):
        """Indexing the cluster returns the node at that position."""
        assert self.cluster[0].name == 'foo@example.com'

    def test_start(self):
        """``start`` calls ``start_node`` once per node, in order."""
        self.cluster.start_node = Mock(name='start_node')
        self.cluster.start()
        # ``assert_has_calls`` is documented to take a list of calls, so a
        # list is passed rather than a generator expression.
        self.cluster.start_node.assert_has_calls(
            [call(node) for node in self.cluster]
        )

    def test_start_node(self):
        """``start_node`` fires ``on_node_start``, delegates to
        ``_start_node`` and reports the result via ``on_node_status``."""
        self.cluster._start_node = Mock(name='_start_node')
        node = self.cluster[0]
        assert (self.cluster.start_node(node) is
                self.cluster._start_node.return_value)
        self.cluster.on_node_start.assert_called_with(node)
        self.cluster._start_node.assert_called_with(node)
        self.cluster.on_node_status.assert_called_with(
            node, self.cluster._start_node(),
        )

    def test__start_node(self):
        """``_start_node`` starts the node with the cluster environment and
        the child callbacks."""
        node = self.cluster[0]
        node.start = Mock(name='node.start')
        assert self.cluster._start_node(node) is node.start.return_value
        node.start.assert_called_with(
            self.cluster.env,
            on_spawn=self.cluster.on_child_spawn,
            on_signalled=self.cluster.on_child_signalled,
            on_failure=self.cluster.on_child_failure,
        )

    def test_send_all(self):
        """``send_all`` reports and sends the signal to every node returned
        by ``getpids``."""
        nodes = [Mock(name='n1'), Mock(name='n2')]
        self.cluster.getpids = Mock(name='getpids')
        self.cluster.getpids.return_value = nodes
        self.cluster.send_all(15)
        # Signal 15 is reported by its short name, ``TERM``.
        self.cluster.on_node_signal.assert_has_calls(
            [call(node, 'TERM') for node in nodes]
        )
        for node in nodes:
            node.send.assert_called_with(15, self.cluster.on_node_signal_dead)

    @skip.if_win32()
    def test_kill(self):
        """``kill`` sends ``SIGKILL`` to all nodes (skipped on win32)."""
        self.cluster.send_all = Mock(name='.send_all')
        self.cluster.kill()
        self.cluster.send_all.assert_called_with(signal.SIGKILL)

    def test_getpids(self):
        """``getpids`` yields only the nodes whose pid can be read."""
        self.gethostname.return_value = 'e.com'
        self.prepare_pidfile_for_getpids(self.Pidfile)
        callback = Mock()

        p = Cluster([
            Node('foo@e.com'),
            Node('bar@e.com'),
            Node('baz@e.com'),
        ])
        nodes = p.getpids(on_down=callback)
        # Only ``foo.pid`` and ``bar.pid`` are readable in the fake pidfile,
        # so exactly two nodes are expected (unpacking enforces the count).
        node_0, node_1 = nodes
        assert node_0.name == 'foo@e.com'
        assert sorted(node_0.argv) == sorted([
            '',
            '--executable={0}'.format(node_0.executable),
            '--logfile=foo%I.log',
            '--pidfile=foo.pid',
            '-m celery worker --detach',
            '-n foo@e.com',
        ])
        assert node_0.pid == 10

        assert node_1.name == 'bar@e.com'
        assert sorted(node_1.argv) == sorted([
            '',
            '--executable={0}'.format(node_1.executable),
            '--logfile=bar%I.log',
            '--pidfile=bar.pid',
            '-m celery worker --detach',
            '-n bar@e.com',
        ])
        assert node_1.pid == 11

        # without callback, should work
        # NOTE(review): the result is never consumed and the positional
        # argument presumably fills ``on_down`` -- confirm this still
        # exercises the intended code path.
        nodes = p.getpids('celery worker')

    def prepare_pidfile_for_getpids(self, Pidfile):
        """Make ``Pidfile`` (a mock) construct fake pidfiles.

        The fake knows pids for ``foo.pid`` (10) and ``bar.pid`` (11);
        reading any other path raises ``ValueError``.
        """
        class pids(object):
            """Minimal pidfile stand-in backed by a fixed path-to-pid map."""

            def __init__(self, path):
                self.path = path

            def read_pid(self):
                try:
                    return {'foo.pid': 10,
                            'bar.pid': 11}[self.path]
                except KeyError:
                    raise ValueError()
        # Configure the mock that was passed in rather than reaching for
        # ``self.Pidfile``; callers in this file pass that same object.
        Pidfile.side_effect = pids