from __future__ import absolute_import, print_function, unicode_literals

import os
import socket
import sys
from collections import deque
from datetime import datetime, timedelta
from functools import partial
from threading import Event

import pytest
from amqp import ChannelError
from case import Mock, patch, skip
from kombu import Connection
from kombu.common import QoS, ignore_errors
from kombu.transport.base import Message
from kombu.transport.memory import Transport
from kombu.utils.uuid import uuid

from celery.bootsteps import CLOSE, RUN, TERMINATE, StartStopStep
from celery.concurrency.base import BasePool
from celery.exceptions import (ImproperlyConfigured, InvalidTaskError,
                               TaskRevokedError, WorkerShutdown,
                               WorkerTerminate)
from celery.five import Empty
from celery.five import Queue as FastQueue
from celery.five import range
from celery.platforms import EX_FAILURE
from celery.utils.nodenames import worker_direct
from celery.utils.serialization import pickle
from celery.utils.timer2 import Timer
from celery.worker import components, consumer, state
from celery.worker import worker as worker_module
from celery.worker.consumer import Consumer
from celery.worker.pidbox import gPidbox
from celery.worker.request import Request


def MockStep(step=None):
    """Return *step* (or a new Mock) dressed up to look like a bootstep.

    When *step* is given, its ``blueprint`` is replaced by a fresh Mock.
    In both cases ``blueprint.name`` is set to ``'MockNS'`` and ``name`` to
    ``'MockStep(<id of the step>)'``.
    """
    if step is not None:
        step.blueprint = Mock(name='step.blueprint')
    else:
        step = Mock(name='step')
    step.blueprint.name = 'MockNS'
    step.name = 'MockStep({0})'.format(id(step))
    return step


def mock_event_dispatcher():
    """Return a Mock event dispatcher in the ``'worker'`` group.

    It carries a fresh, empty ``deque`` as its ``_outbound_buffer``.
    """
    dispatcher = Mock(name='event_dispatcher')
    dispatcher.groups, dispatcher._outbound_buffer = ['worker'], deque()
    return dispatcher


def find_step(obj, typ):
    """Look up the step named ``typ.name`` in ``obj.blueprint.steps``."""
    steps = obj.blueprint.steps
    return steps[typ.name]


def create_message(channel, **data):
    """Build a kombu :class:`Message` on *channel* whose body pickles *data*.

    An ``id`` from ``uuid()`` is added to *data* unless one is supplied.
    The message uses the ``'mock'`` consumer tag and only accepts the
    ``application/x-python-serialize`` content type.
    """
    data.setdefault('id', uuid())
    serializer_type = 'application/x-python-serialize'
    message = Message(
        body=pickle.dumps(dict(**data)),
        channel=channel,
        content_type=serializer_type,
        content_encoding='binary',
        delivery_info={'consumer_tag': 'mock'},
    )
    message.accept = [serializer_type]
    return message


# Explicit ``object`` base so this is a new-style class on Python 2 as well,
# matching the other classes in this module.
class ConsumerCase(object):
    """Mixin with helpers shared by the consumer/worker test classes.

    NOTE(review): ``self.TaskMessage`` is not defined here; presumably it is
    injected onto test instances by a fixture — confirm.
    """

    def create_task_message(self, channel, *args, **kwargs):
        """Return ``self.TaskMessage(*args, **kwargs)`` attached to *channel*.

        The message's ``delivery_info`` is set to the ``'mock'`` consumer tag.
        """
        m = self.TaskMessage(*args, **kwargs)
        m.channel = channel
        m.delivery_info = {'consumer_tag': 'mock'}
        return m


class test_Consumer(ConsumerCase):
    """Tests for :class:`celery.worker.consumer.Consumer`.

    Consumers are built with :meth:`LoopConsumer` (real ``loop``) or
    :meth:`NoopConsumer` (``loop`` replaced by a Mock).  Tests that modify
    process-global objects (the memory ``Transport`` class, the collections
    in ``celery.worker.state``) undo their changes before returning so they
    cannot leak into other tests.
    """

    def setup(self):
        """Create the request buffer and timer and register ``foo_task``."""
        self.buffer = FastQueue()
        self.timer = Timer()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        """Stop the timer created in :meth:`setup`."""
        self.timer.stop()

    def LoopConsumer(self, buffer=None, controller=None, timer=None, app=None,
                     without_mingle=True, without_gossip=True,
                     without_heartbeat=True, **kwargs):
        """Return a :class:`Consumer` wired to mocks, keeping its real loop.

        *buffer* defaults to ``self.buffer.put``, *timer* to ``self.timer``,
        *app* to ``self.app`` and *controller* to a Mock; these, the
        ``without_*`` flags and any extra *kwargs* are passed to
        :class:`Consumer`.  After construction the task consumer,
        connection, heart, pool and node are set to mocks, ``controller`` to
        a real ``WorkController`` (so the *controller* argument only matters
        during construction), the event dispatcher to
        :func:`mock_event_dispatcher` and ``qos`` to a :class:`QoS` with a
        prefetch count of 10.
        """
        if controller is None:
            controller = Mock(name='.controller')
        buffer = buffer if buffer is not None else self.buffer.put
        timer = timer if timer is not None else self.timer
        app = app if app is not None else self.app
        c = Consumer(
            buffer,
            timer=timer,
            app=app,
            controller=controller,
            without_mingle=without_mingle,
            without_gossip=without_gossip,
            without_heartbeat=without_heartbeat,
            **kwargs
        )
        c.task_consumer = Mock(name='.task_consumer')
        c.qos = QoS(c.task_consumer.qos, 10)
        c.connection = Mock(name='.connection')
        c.controller = c.app.WorkController()
        c.heart = Mock(name='.heart')
        c.controller.consumer = c
        c.pool = c.controller.pool = Mock(name='.controller.pool')
        c.node = Mock(name='.node')
        c.event_dispatcher = mock_event_dispatcher()
        return c

    def NoopConsumer(self, *args, **kwargs):
        """Like :meth:`LoopConsumer`, but with ``loop`` replaced by a Mock."""
        c = self.LoopConsumer(*args, **kwargs)
        c.loop = Mock(name='.loop')
        return c

    def test_info(self):
        c = self.NoopConsumer()
        c.connection.info.return_value = {'foo': 'bar'}
        c.controller.pool.info.return_value = [Mock(), Mock()]
        info = c.controller.stats()
        assert info['prefetch_count'] == 10
        assert info['broker']

    def test_start_when_closed(self):
        c = self.NoopConsumer()
        c.blueprint.state = CLOSE
        c.start()

    def test_connection(self):
        c = self.NoopConsumer()

        c.blueprint.start(c)
        assert isinstance(c.connection, Connection)

        c.blueprint.state = RUN
        c.event_dispatcher = None
        c.blueprint.restart(c)
        assert c.connection

        c.blueprint.state = RUN
        c.shutdown()
        assert c.connection is None
        assert c.task_consumer is None

        c.blueprint.start(c)
        assert isinstance(c.connection, Connection)
        c.blueprint.restart(c)

        c.stop()
        c.shutdown()
        assert c.connection is None
        assert c.task_consumer is None

    def test_close_connection(self):
        c = self.NoopConsumer()
        c.blueprint.state = RUN
        step = find_step(c, consumer.Connection)
        connection = c.connection
        step.shutdown(c)
        connection.close.assert_called()
        assert c.connection is None

    def test_close_connection__heart_shutdown(self):
        c = self.NoopConsumer()
        event_dispatcher = c.event_dispatcher
        heart = c.heart
        c.event_dispatcher.enabled = True
        c.blueprint.state = RUN
        Events = find_step(c, consumer.Events)
        Events.shutdown(c)
        Heart = find_step(c, consumer.Heart)
        Heart.shutdown(c)
        event_dispatcher.close.assert_called()
        heart.stop.assert_called_with()

    @patch('celery.worker.consumer.consumer.warn')
    def test_receive_message_unknown(self, warn):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='.channel')
        m = create_message(channel, unknown={'baz': '!!!'})

        callback = self._get_on_message(c)
        callback(m)
        warn.assert_called()

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        # Converting the ETA fails because to_timestamp is patched to raise.
        to_timestamp.side_effect = OverflowError()
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=datetime.now().isoformat(),
        )
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)
        # NOTE(review): if TaskMessage returns a Mock, ``acknowledged`` is an
        # auto-created (always truthy) attribute and this check is vacuous —
        # confirm and consider asserting on ack/reject calls instead.
        assert m.acknowledged

    @patch('celery.worker.consumer.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=(1, 2), kwargs='foobarbaz', id=1)
        c.update_strategies()
        strat = c.strategies[self.foo_task.name] = Mock(name='strategy')
        strat.side_effect = InvalidTaskError()

        callback = self._get_on_message(c)
        callback(m)
        error.assert_called()
        assert 'Received invalid task message' in error.call_args[0][0]

    @patch('celery.worker.consumer.consumer.crit')
    def test_on_decode_error(self, crit):
        c = self.LoopConsumer()

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        c.on_decode_error(message, KeyError('foo'))
        assert message.ack.call_count
        assert "Can't decode message body" in crit.call_args[0][0]

    def _get_on_message(self, c):
        """Run ``c.loop`` once and return the ``on_message`` it installed.

        The connection is mocked so that ``drain_events`` raises
        WorkerShutdown, which ends the loop on its first drain; the callback
        the loop assigned to ``c.task_consumer.on_message`` is returned.
        """
        if c.qos is None:
            c.qos = Mock()
        c.task_consumer = Mock()
        c.event_dispatcher = mock_event_dispatcher()
        c.connection = Mock(name='.connection')
        c.connection.get_heartbeat_interval.return_value = 0
        c.connection.drain_events.side_effect = WorkerShutdown()

        with pytest.raises(WorkerShutdown):
            c.loop(*c.loop_args())
        assert c.task_consumer.on_message
        return c.task_consumer.on_message

    def test_receieve_message(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)

        in_bucket = self.buffer.get_nowait()
        assert isinstance(in_bucket, Request)
        assert in_bucket.name == self.foo_task.name
        assert in_bucket.execute() == 2 * 4 * 8
        assert self.timer.empty()

    def test_start_channel_error(self):
        c = self.NoopConsumer(task_events=False, pool=BasePool())
        c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        c.channel_errors = (KeyError,)
        try:
            with pytest.raises(KeyError):
                c.start()
        finally:
            if c.timer:
                c.timer.stop()

    def test_start_connection_error(self):
        c = self.NoopConsumer(task_events=False, pool=BasePool())
        c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        c.connection_errors = (KeyError,)
        try:
            with pytest.raises(SyntaxError):
                c.start()
        finally:
            if c.timer:
                c.timer.stop()

    def test_loop_ignores_socket_timeout(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        c = self.NoopConsumer()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)
        c.loop(*c.loop_args())

    def test_loop_when_socket_error(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        conn = c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)
        with pytest.raises(socket.error):
            c.loop(*c.loop_args())

        c.blueprint.state = CLOSE
        c.connection = conn
        c.loop(*c.loop_args())

    def test_loop(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

            @property
            def supports_heartbeats(self):
                return False

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.qos = QoS(c.task_consumer.qos, 10)

        c.loop(*c.loop_args())
        c.loop(*c.loop_args())
        assert c.task_consumer.consume.call_count
        c.task_consumer.qos.assert_called_with(prefetch_count=10)
        assert c.qos.value == 10
        c.qos.decrement_eventually()
        assert c.qos.value == 9
        c.qos.update()
        assert c.qos.value == 9
        c.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        c = self.NoopConsumer()
        c.connection_errors = (AttributeError, KeyError,)
        c.channel_errors = (SyntaxError,)
        ignore_errors(c, Mock(side_effect=AttributeError('foo')))
        ignore_errors(c, Mock(side_effect=KeyError('foo')))
        ignore_errors(c, Mock(side_effect=SyntaxError('foo')))
        with pytest.raises(IndexError):
            ignore_errors(c, Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        c = self.NoopConsumer()
        c.qos = QoS(None, 10)
        task = Mock(name='task', id='1234213')
        qos = c.qos.value
        try:
            c.apply_eta_task(task)
            assert task in state.reserved_requests
            assert c.qos.value == qos - 1
            assert self.buffer.get_nowait() is task
        finally:
            # apply_eta_task registers the task in the process-global
            # ``state.reserved_requests``; remove it so it can't leak.
            state.reserved_requests.discard(task)

    def test_receieve_message_eta_isoformat(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        c.qos = QoS(c.task_consumer.qos, 1)
        current_pcount = c.qos.value
        c.event_dispatcher.enabled = False
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)
        c.timer.stop()
        c.timer.join(1)

        # The ETA task should have been scheduled on the timer.
        items = [entry[2] for entry in self.timer.queue]
        assert any(item.args[0].name == self.foo_task.name for item in items)
        assert c.qos.value > current_pcount

    def test_pidbox_callback(self):
        c = self.NoopConsumer()
        con = find_step(c, consumer.Control).box
        con.node = Mock()
        con.reset = Mock()

        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = KeyError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = ValueError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')
        con.reset.assert_called()

    def test_revoke(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='channel')
        task_id = uuid()
        t = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={}, id=task_id,
        )

        # ``state.revoked`` is process-global: undo the addition afterwards.
        state.revoked.add(task_id)
        try:
            callback = self._get_on_message(c)
            callback(t)
            assert self.buffer.empty()
        finally:
            state.revoked.discard(task_id)

    def test_receieve_message_not_registered(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='channel')
        m = self.create_task_message(
            channel, 'x.X.31x', args=[2, 4, 8], kwargs={},
        )

        callback = self._get_on_message(c)
        assert not callback(m)
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()

    @patch('celery.worker.consumer.consumer.warn')
    @patch('celery.worker.consumer.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        channel = Mock(name='channel')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        m.headers = None

        c.update_strategies()
        c.connection_errors = (socket.error,)
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        callback = self._get_on_message(c)
        assert not callback(m)
        warn.assert_called()
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()
        m.reject_log_error.assert_called_with(logger, c.connection_errors)

    def test_receive_message_eta(self):
        # Progress tracing goes to stderr only when C_DEBUG_TEST is set.
        if os.environ.get('C_DEBUG_TEST'):
            pp = partial(print, file=sys.__stderr__)
        else:
            def pp(*args, **kwargs):
                pass
        pp('TEST RECEIVE MESSAGE ETA')
        pp('+CREATE MYKOMBUCONSUMER')
        c = self.LoopConsumer()
        pp('-CREATE MYKOMBUCONSUMER')
        c.steps.pop()
        channel = Mock(name='channel')
        pp('+ CREATE MESSAGE')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )
        pp('- CREATE MESSAGE')

        try:
            pp('+ BLUEPRINT START 1')
            c.blueprint.start(c)
            pp('- BLUEPRINT START 1')
            p = c.app.conf.broker_connection_retry
            c.app.conf.broker_connection_retry = False
            pp('+ BLUEPRINT START 2')
            c.blueprint.start(c)
            pp('- BLUEPRINT START 2')
            c.app.conf.broker_connection_retry = p
            pp('+ BLUEPRINT RESTART')
            c.blueprint.restart(c)
            pp('- BLUEPRINT RESTART')
            pp('+ GET ON MESSAGE')
            callback = self._get_on_message(c)
            pp('- GET ON MESSAGE')
            pp('+ CALLBACK')
            callback(m)
            pp('- CALLBACK')
        finally:
            pp('+ STOP TIMER')
            c.timer.stop()
            pp('- STOP TIMER')
            try:
                pp('+ JOIN TIMER')
                c.timer.join()
                pp('- JOIN TIMER')
            except RuntimeError:
                pass

        in_hold = c.timer.queue[0]
        assert len(in_hold) == 3
        eta, priority, entry = in_hold
        task = entry.args[0]
        assert isinstance(task, Request)
        assert task.name == self.foo_task.name
        assert task.execute() == 2 * 4 * 8
        with pytest.raises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        c = self.NoopConsumer()
        con = find_step(c, consumer.Control).box
        con.node = Mock()
        chan = con.node.channel = Mock()
        chan.close.side_effect = socket.error('foo')
        c.connection_errors = (socket.error,)
        con.reset()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        c = self.NoopConsumer(pool=Mock(is_green=True))
        con = find_step(c, consumer.Control)
        assert isinstance(con.box, gPidbox)
        con.start(c)
        c.pool.spawn_n.assert_called_with(con.box.loop, c)

    def test_green_pidbox_node(self):
        c = self.NoopConsumer(pool=Mock(is_green=True))
        controller = find_step(c, consumer.Control)

        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First drain times out; the second ends the pidbox loop.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        c.connection_for_read = lambda: Connection(obj=c)
        controller.box.loop(c)

        controller.box.node.listen.assert_called()
        assert controller.box.consumer
        controller.box.consumer.consume.assert_called_with()

        assert c.connection is None
        assert connections[0].closed

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.functional.sleep')
    def test_connect_errback(self, sleep, connect):
        c = self.NoopConsumer()
        # ``Transport`` is a shared class: remember whether it defined the
        # attribute itself so the exact original state can be restored.
        had_own_attr = 'connection_errors' in vars(Transport)
        prev_connection_errors = Transport.connection_errors
        Transport.connection_errors = (ChannelError,)
        try:
            connect.on_nth_call_do(ChannelError('error'), n=1)
            c.connect()
            connect.assert_called_with()
        finally:
            if had_own_attr:
                Transport.connection_errors = prev_connection_errors
            else:
                del Transport.connection_errors

    def test_stop_pidbox_node(self):
        c = self.NoopConsumer()
        cont = find_step(c, consumer.Control)
        cont._node_stopped = Event()
        cont._node_shutdown = Event()
        cont._node_stopped.set()
        cont.stop(c)

    def test_start__loop(self):

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        init_callback = Mock(name='init_callback')
        c = self.NoopConsumer(init_callback=init_callback)
        c.qos = _QoS()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.iterations = 0

        def raises_KeyError(*args, **kwargs):
            c.iterations += 1
            if c.qos.prev != c.qos.value:
                c.qos.update()
            if c.iterations >= 2:
                raise KeyError('foo')

        c.loop = raises_KeyError
        with pytest.raises(KeyError):
            c.start()
        assert c.iterations == 2
        assert c.qos.prev == c.qos.value

        init_callback.reset_mock()
        c = self.NoopConsumer(task_events=False, init_callback=init_callback)
        c.qos = _QoS()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.loop = Mock(side_effect=socket.error('foo'))
        with pytest.raises(socket.error):
            c.start()
        c.loop.assert_called()

    def test_reset_connection_with_no_node(self):
        c = self.NoopConsumer()
        c.steps.pop()
        c.blueprint.start(c)


class test_WorkController(ConsumerCase):

    def setup(self):
        """Create a worker, swap both module loggers for mocks, add a task.

        The real ``worker_module.logger`` and ``components.logger`` are kept
        in ``_logger`` / ``_comp_logger`` for :meth:`teardown`.
        """
        self.worker = self.create_worker()
        saved_loggers = (worker_module.logger, components.logger)
        self._logger, self._comp_logger = saved_loggers
        self.logger = Mock()
        worker_module.logger = self.logger
        self.comp_logger = Mock()
        components.logger = self.comp_logger

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        """Restore the real loggers saved by :meth:`setup`."""
        components.logger = self._comp_logger
        worker_module.logger = self._logger

    def create_worker(self, **kw):
        """Return a ``WorkController`` (concurrency 1, loglevel 0) built with *kw*.

        Its blueprint's ``shutdown_complete`` event is set before returning.
        """
        new_worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
        new_worker.blueprint.shutdown_complete.set()
        return new_worker

    def test_on_consumer_ready(self):
        """on_consumer_ready() accepts an arbitrary consumer without error."""
        fake_consumer = Mock()
        self.worker.on_consumer_ready(fake_consumer)

    def test_setup_queues_worker_direct(self):
        """With ``worker_direct`` on, the worker's direct queue is added."""
        self.app.conf.worker_direct = True
        amqp = self.app.amqp
        # Stored straight into ``__dict__`` so attribute lookup returns the
        # mock (NOTE(review): presumably ``queues`` is a cached property).
        amqp.__dict__['queues'] = Mock()
        self.worker.setup_queues({})
        direct_queue = worker_direct(self.worker.hostname)
        amqp.queues.select_add.assert_called_with(direct_queue)

    def test_setup_queues__missing_queue(self):
        """A KeyError from select or deselect becomes ImproperlyConfigured."""
        queues = self.app.amqp.queues
        queues.select = Mock(name='select', side_effect=KeyError())
        queues.deselect = Mock(name='deselect', side_effect=KeyError())
        with pytest.raises(ImproperlyConfigured):
            self.worker.setup_queues('x,y', exclude='foo,bar')

        # Now only deselect fails.
        queues.select = Mock(name='select')
        with pytest.raises(ImproperlyConfigured):
            self.worker.setup_queues('x,y', exclude='foo,bar')

    def test_send_worker_shutdown(self):
        """_send_worker_shutdown() sends ``worker_shutdown`` for this worker."""
        with patch('celery.signals.worker_shutdown') as shutdown_signal:
            self.worker._send_worker_shutdown()
        shutdown_signal.send.assert_called_with(sender=self.worker)

    @skip.todo('unstable test')
    def test_process_shutdown_on_worker_shutdown(self):
        """A pool Worker exiting via process_destructor signals, then exits.

        Expects ``worker_process_shutdown`` to be sent with the pid/exitcode
        and ``os._exit`` to be called with the exit code.  Currently skipped.
        """
        from celery.concurrency.prefork import process_destructor
        from celery.concurrency.asynpool import Worker
        with patch('celery.signals.worker_process_shutdown') as signal, \
                patch('os._exit') as fake_exit:
            pool_worker = Worker(None, None, on_exit=process_destructor)
            pool_worker._do_exit(22, 3.1415926)
            signal.send.assert_called_with(
                sender=None, pid=22, exitcode=3.1415926,
            )
            fake_exit.assert_called_with(3.1415926)

    def test_process_task_revoked_release_semaphore(self):
        """A revoked task releases the semaphore via ``_quick_release``."""
        worker = self.worker
        worker._quick_release = Mock()
        request = Mock()
        request.execute_using_pool.side_effect = TaskRevokedError
        worker._process_task(request)
        worker._quick_release.assert_called_with()

        # With no ``_quick_release`` attribute the call must still not raise.
        del worker._quick_release
        worker._process_task(request)

    def test_shutdown_no_blueprint(self):
        """_shutdown() does not raise when the worker has no blueprint."""
        worker = self.worker
        worker.blueprint = None
        worker._shutdown()

    @patch('celery.worker.worker.create_pidlock')
    def test_use_pidfile(self, create_pidlock):
        """With ``pidfile`` set, start() creates a pidlock; stop() releases it."""
        create_pidlock.return_value = Mock()
        w = self.create_worker(pidfile='pidfilelockfilepid')
        w.steps = []

        w.start()
        create_pidlock.assert_called()

        w.stop()
        w.pidlock.release.assert_called()

    def test_attrs(self):
        """A fresh worker has a Timer, a pool, a consumer and some steps."""
        w = self.worker
        assert w.timer is not None
        assert isinstance(w.timer, Timer)
        for attr in ('pool', 'consumer'):
            assert getattr(w, attr) is not None
        assert w.steps

    def test_with_embedded_beat(self):
        """``beat=True`` creates a beat service registered as a step object."""
        w = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
        assert w.beat
        step_objects = [step.obj for step in w.steps]
        assert w.beat in step_objects

    def test_with_autoscaler(self):
        """Passing ``autoscale`` gives the worker an autoscaler."""
        options = dict(
            autoscale=[10, 3], send_events=False,
            timer_cls='celery.utils.timer2.Timer',
        )
        w = self.create_worker(**options)
        assert w.autoscaler

    def test_dont_stop_or_terminate(self):
        """stop()/terminate() leave the blueprint un-CLOSEd.

        Covers both the unstarted worker and a RUNning one called from a
        signal handler while the pool is not signal-safe.
        """
        w = self.app.WorkController(concurrency=1, loglevel=0)
        for shutdown in (w.stop, w.terminate):
            shutdown()
            assert w.blueprint.state != CLOSE

        saved_signal_safe = w.pool.signal_safe
        w.pool.signal_safe = False
        try:
            w.blueprint.state = RUN
            for shutdown in (w.stop, w.terminate):
                shutdown(in_sighandler=True)
                assert w.blueprint.state != CLOSE
        finally:
            w.pool.signal_safe = saved_signal_safe

    def test_on_timer_error(self):
        """Timer.on_timer_error logs an error mentioning the exception type."""
        w = self.app.WorkController(concurrency=1, loglevel=0)

        try:
            raise KeyError('foo')
        except KeyError as error:
            # Called while the exception is being handled — presumably the
            # handler logs with exc_info; keep it inside ``except``.
            components.Timer(w).on_timer_error(error)
            fmt, fmt_args = self.comp_logger.error.call_args[0]
            rendered = fmt % fmt_args
            assert 'KeyError' in rendered

    def test_on_timer_tick(self):
        """Timer.on_timer_tick logs a debug 'Next ETA' line with the delay."""
        w = self.app.WorkController(concurrency=1, loglevel=10)

        components.Timer(w).on_timer_tick(30.0)
        positional = self.comp_logger.debug.call_args[0]
        message_fmt = positional[0]
        delay = positional[1]
        assert delay == 30.0
        assert 'Next ETA %s secs' in message_fmt

    def test_process_task(self):
        """_process_task hands the request to ``pool.apply_async`` once."""
        w = self.worker
        w.pool = Mock()
        message = self.create_task_message(
            Mock(), self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message, app=self.app)
        w._process_task(request)
        assert w.pool.apply_async.call_count == 1
        w.pool.stop()

    def test_process_task_raise_base(self):
        """A KeyboardInterrupt from the pool propagates out of _process_task."""
        w = self.worker
        w.pool = Mock()
        w.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
        message = self.create_task_message(
            Mock(), self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message, app=self.app)
        w.steps = []
        w.blueprint.state = RUN
        with pytest.raises(KeyboardInterrupt):
            w._process_task(request)

    def test_process_task_raise_WorkerTerminate(self):
        """WorkerTerminate from the pool surfaces as SystemExit."""
        w = self.worker
        w.pool = Mock()
        w.pool.apply_async.side_effect = WorkerTerminate()
        message = self.create_task_message(
            Mock(), self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message, app=self.app)
        w.steps = []
        w.blueprint.state = RUN
        with pytest.raises(SystemExit):
            w._process_task(request)

    def test_process_task_raise_regular(self):
        """An ordinary exception from the pool propagates unchanged."""
        w = self.worker
        w.pool = Mock()
        w.pool.apply_async.side_effect = KeyError('some exception')
        message = self.create_task_message(
            Mock(), self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message, app=self.app)
        with pytest.raises(KeyError):
            w._process_task(request)
        w.pool.stop()

    def test_start_catches_base_exceptions(self):
        """start() swallows WorkerTerminate / WorkerShutdown raised by a step.

        WorkerTerminate leads to the step's ``terminate`` being called;
        WorkerShutdown (step without ``terminate``) leads to ``stop``.
        """
        terminating = self.create_worker()
        terminating.blueprint.state = RUN
        term_step = MockStep()
        term_step.start.side_effect = WorkerTerminate()
        terminating.steps = [term_step]
        terminating.start()
        term_step.start.assert_called_with(terminating)
        assert term_step.terminate.call_count

        stopping = self.create_worker()
        stopping.blueprint.state = RUN
        stop_step = MockStep()
        stop_step.start.side_effect = WorkerShutdown()
        stop_step.terminate = None
        stopping.steps = [stop_step]
        stopping.start()
        assert stop_step.stop.call_count

    def test_statedb(self):
        """A ``statedb`` argument makes the worker set up persistence."""
        # ``state`` is the module-level ``celery.worker.state`` import.
        saved_persistent = state.Persistent
        state.Persistent = Mock()
        try:
            w = self.create_worker(statedb='statefilename')
            assert w._persistence
        finally:
            state.Persistent = saved_persistent

    def test_process_task_sem(self):
        """_process_task_sem defers _process_task via ``_quick_acquire``."""
        w = self.worker
        w._quick_acquire = Mock()
        request = Mock()
        w._process_task_sem(request)
        w._quick_acquire.assert_called_with(w._process_task, request)

    def test_signal_consumer_close(self):
        """signal_consumer_close() closes the consumer.

        An AttributeError raised by ``close`` must not propagate.
        """
        w = self.worker
        w.consumer = Mock()
        w.signal_consumer_close()
        w.consumer.close.assert_called_with()

        w.consumer.close.side_effect = AttributeError()
        w.signal_consumer_close()

    def test_rusage__no_resource(self):
        """Without ``resource``, rusage() raises NotImplementedError.

        stats() still works in that situation.
        """
        # ``worker_module`` is the module-level alias of celery.worker.worker.
        saved_resource = worker_module.resource
        worker_module.resource = None
        try:
            self.worker.pool = Mock(name='pool')
            with pytest.raises(NotImplementedError):
                self.worker.rusage()
            self.worker.stats()
        finally:
            worker_module.resource = saved_resource

    def test_repr(self):
        """repr() of a worker is a non-empty string."""
        text = repr(self.worker)
        assert text

    def test_str(self):
        """str() of a worker is its hostname."""
        w = self.worker
        assert str(w) == w.hostname

    def test_start__stop(self):
        """start()/stop() drive every step's start/close/stop.

        Also checks that stop() copes with a missing pool and with a ``None``
        entry in ``worker.steps``.
        """
        worker = self.worker
        worker.blueprint.shutdown_complete.set()
        # NOTE(review): each StartStopStep is constructed with the test case
        # itself as its parent; MockStep then gives it a mock blueprint.
        worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
        worker.blueprint.state = RUN
        worker.blueprint.started = 4
        # Replace the lifecycle methods with mocks so calls can be asserted.
        for w in worker.steps:
            w.start = Mock()
            w.close = Mock()
            w.stop = Mock()

        worker.start()
        for w in worker.steps:
            w.start.assert_called()
        worker.consumer = Mock()
        worker.stop(exitcode=3)
        for stopstep in worker.steps:
            stopstep.close.assert_called()
            stopstep.stop.assert_called()

        # Doesn't close pool if no pool.
        worker.start()
        worker.pool = None
        worker.stop()

        # test that stop of None is not attempted
        worker.steps[-1] = None
        worker.start()
        worker.stop()

    def test_start__KeyboardInterrupt(self):
        """A KeyboardInterrupt during start() triggers stop(EX_FAILURE)."""
        w = self.worker
        blueprint = w.blueprint = Mock(name='blueprint')
        blueprint.start.side_effect = KeyboardInterrupt()
        w.stop = Mock(name='stop')
        w.start()
        w.stop.assert_called_with(exitcode=EX_FAILURE)

    def test_register_with_event_loop(self):
        """register_with_event_loop() fans out to all steps via send_all."""
        w = self.worker
        hub = Mock(name='hub')
        w.blueprint = Mock(name='blueprint')
        w.register_with_event_loop(hub)
        expected_kwargs = dict(args=(hub,), description='hub.register')
        w.blueprint.send_all.assert_called_with(
            w, 'register_with_event_loop', **expected_kwargs
        )

    def test_step_raises(self):
        """An ordinary error in a step's start() leads to stop(EX_FAILURE)."""
        w = self.worker
        failing_step = Mock()
        failing_step.start.side_effect = TypeError()
        w.steps = [failing_step]
        w.stop = Mock()
        w.start()
        w.stop.assert_called_with(exitcode=EX_FAILURE)

    def test_state(self):
        """The worker exposes a truthy ``state`` attribute."""
        w = self.worker
        assert w.state

    def test_start__terminate(self):
        """start() starts every step; terminate() terminates every step.

        A second terminate() after the blueprint reached TERMINATE must not
        raise.
        """
        worker = self.worker
        worker.blueprint.shutdown_complete.set()
        worker.blueprint.started = 5
        worker.blueprint.state = RUN
        worker.steps = [MockStep() for _ in range(5)]
        worker.start()
        # Check every step, not just the first few: ``started`` below
        # asserts that all of them were started.
        for step in worker.steps:
            step.start.assert_called()
        assert worker.blueprint.started == len(worker.steps)
        assert worker.blueprint.state == RUN
        worker.terminate()
        for step in worker.steps:
            step.terminate.assert_called()
        worker.blueprint.state = TERMINATE
        worker.terminate()

    def test_Hub_create(self):
        """Hub.create() configures the worker timer's ``max_interval``."""
        w = Mock()
        hub_step = components.Hub(w)
        hub_step.create(w)
        assert w.timer.max_interval

    def test_Pool_create_threaded(self):
        """Pool.create() runs without an event loop (``use_eventloop=False``)."""
        w = Mock()
        conninfo = w._conninfo
        conninfo.connection_errors = ()
        conninfo.channel_errors = ()
        w.pool_cls = Mock()
        w.use_eventloop = False
        components.Pool(w).create(w)

    def test_Pool_pool_no_sem(self):
        """Without a semaphore, ``process_task`` is ``_process_task`` itself."""
        w = Mock()
        w.pool_cls.uses_semaphore = False
        pool_step = components.Pool(w)
        pool_step.create(w)
        assert w.process_task is w._process_task

    def test_Pool_create(self):
        from kombu.asynchronous.semaphore import LaxBoundedSemaphore
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.hub = Mock()

        PoolImp = Mock()
        poolimp = PoolImp.return_value = Mock()
        poolimp._pool = [Mock(), Mock()]
        poolimp._cache = {}
        poolimp._fileno_to_inq = {}
        poolimp._fileno_to_outq = {}

        from celery.concurrency.prefork import TaskPool as _TaskPool

        class MockTaskPool(_TaskPool):
            Pool = PoolImp

            @property
            def timers(self):
                return {Mock(): 30}

        w.pool_cls = MockTaskPool
        w.use_eventloop = True
        w.consumer.restart_count = -1
        pool = components.Pool(w)
        pool.create(w)
        pool.register_with_event_loop(w, w.hub)
        if sys.platform != 'win32':
            assert isinstance(w.semaphore, LaxBoundedSemaphore)
            P = w.pool
            P.start()