| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474 | from __future__ import absolute_import, unicode_literalsimport errnoimport socketimport pytestfrom case import Mockfrom kombu.async import Hub, READ, WRITE, ERRfrom celery.bootsteps import CLOSE, RUNfrom celery.exceptions import (    InvalidTaskError, WorkerLostError, WorkerShutdown, WorkerTerminate,)from celery.five import Empty, python_2_unicode_compatiblefrom celery.platforms import EX_FAILUREfrom celery.worker import statefrom celery.worker.consumer import Consumerfrom celery.worker.loops import _quick_drain, asynloop, synloop@python_2_unicode_compatibleclass PromiseEqual(object):    def __init__(self, fun, *args, **kwargs):        self.fun = fun        self.args = args        self.kwargs = kwargs    def __eq__(self, other):        return (other.fun == self.fun and                other.args == self.args and                other.kwargs == self.kwargs)    def __repr__(self):        return '<promise: {0.fun!r} {0.args!r} {0.kwargs!r}>'.format(self)class X(object):    def __init__(self, app, heartbeat=None, on_task_message=None,                 transport_driver_type=None):        hub = Hub()        (            self.obj,            self.connection,            self.consumer,            self.blueprint,            self.hub,            self.qos,            self.heartbeat,            self.clock,        ) = self.args = [Mock(name='obj'),                         Mock(name='connection'),                         Mock(name='consumer'),                         Mock(name='blueprint'),                         hub,                         Mock(name='qos'),                         heartbeat,                         Mock(name='clock')]        self.connection.supports_heartbeats = True        self.connection.get_heartbeat_interval.side_effect = (            lambda: self.heartbeat        )        self.consumer.callbacks = []        self.obj.strategies = {}        self.connection.connection_errors = (socket.error,)        if transport_driver_type:            self.connection.transport.driver_type = transport_driver_type        self.hub.readers = {}        self.hub.timer = Mock(name='hub.timer')        self.hub.timer._queue = [Mock()]        self.hub.fire_timers = Mock(name='hub.fire_timers')        self.hub.fire_timers.return_value = 1.7        self.hub.poller = Mock(name='hub.poller')        self.hub.close = Mock(name='hub.close()')  # asynloop calls hub.close        self.Hub = self.hub        self.blueprint.state = RUN        # need this for create_task_handler        self._consumer = _consumer = Consumer(            Mock(), timer=Mock(), controller=Mock(), app=app)        _consumer.on_task_message = on_task_message or []        self.obj.create_task_handler = _consumer.create_task_handler        self.on_unknown_message = self.obj.on_unknown_message = Mock(            name='on_unknown_message',        )        _consumer.on_unknown_message = self.on_unknown_message        self.on_unknown_task = self.obj.on_unknown_task = Mock(            name='on_unknown_task',        )        _consumer.on_unknown_task = self.on_unknown_task        self.on_invalid_task = self.obj.on_invalid_task = Mock(            name='on_invalid_task',        )        _consumer.on_invalid_task = self.on_invalid_task        _consumer.strategies = self.obj.strategies    def timeout_then_error(self, mock):        def first(*args, **kwargs):            mock.side_effect = socket.error()            raise socket.timeout()        mock.side_effect = first    def close_then_error(self, mock=None, mod=0, exc=None):        mock = Mock() if mock is None else mock        def first(*args, **kwargs):            if not mod or mock.call_count > mod:                self.close()                raise (socket.error() if exc is None else exc)        mock.side_effect = first        return mock    def close(self, *args, **kwargs):        self.blueprint.state = CLOSE    def closer(self, mock=None, mod=0):        mock = Mock() if mock is None else mock        def closing(*args, **kwargs):            if not mod or mock.call_count >= mod:                self.close()        mock.side_effect = closing        return mockdef get_task_callback(*args, **kwargs):    x = X(*args, **kwargs)    x.blueprint.state = CLOSE    asynloop(*x.args)    return x, x.consumer.on_messageclass test_asynloop:    def setup(self):        @self.app.task(shared=False)        def add(x, y):            return x + y        self.add = add    def test_drain_after_consume(self):        x, _ = get_task_callback(self.app, transport_driver_type='amqp')        assert _quick_drain in [p.fun for p in x.hub._ready]    def test_pool_did_not_start_at_startup(self):        x = X(self.app)        x.obj.restart_count = 0        x.obj.pool.did_start_ok.return_value = False        with pytest.raises(WorkerLostError):            asynloop(*x.args)    def test_setup_heartbeat(self):        x = X(self.app, heartbeat=10)        x.hub.timer.call_repeatedly = Mock(name='x.hub.call_repeatedly()')        x.blueprint.state = CLOSE        asynloop(*x.args)        x.consumer.consume.assert_called_with()        x.obj.on_ready.assert_called_with()        x.hub.timer.call_repeatedly.assert_called_with(            10 / 2.0, x.connection.heartbeat_check, (2.0,),        )    def task_context(self, sig, **kwargs):        x, on_task = get_task_callback(self.app, **kwargs)        message = self.task_message_from_sig(self.app, sig)        strategy = x.obj.strategies[sig.task] = Mock(name='strategy')        return x, on_task, message, strategy    def test_on_task_received(self):        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))        on_task(msg)        strategy.assert_called_with(            msg, None,            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),            PromiseEqual(x._consumer.call_soon, msg.reject_log_error), [],        )    def test_on_task_received_executes_on_task_message(self):        cbs = [Mock(), Mock(), Mock()]        x, on_task, msg, strategy = self.task_context(            self.add.s(2, 2), on_task_message=cbs,        )        on_task(msg)        strategy.assert_called_with(            msg, None,            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),            PromiseEqual(x._consumer.call_soon, msg.reject_log_error),            cbs,        )    def test_on_task_message_missing_name(self):        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))        msg.headers.pop('task')        on_task(msg)        x.on_unknown_message.assert_called_with(msg.decode(), msg)    def test_on_task_pool_raises(self):        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))        strategy.side_effect = ValueError()        with pytest.raises(ValueError):            on_task(msg)    def test_on_task_InvalidTaskError(self):        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))        exc = strategy.side_effect = InvalidTaskError()        on_task(msg)        x.on_invalid_task.assert_called_with(None, msg, exc)    def test_should_terminate(self):        x = X(self.app)        # XXX why aren't the errors propagated?!?        state.should_terminate = True        try:            with pytest.raises(WorkerTerminate):                asynloop(*x.args)        finally:            state.should_terminate = None    def test_should_terminate_hub_close_raises(self):        x = X(self.app)        # XXX why aren't the errors propagated?!?        state.should_terminate = EX_FAILURE        x.hub.close.side_effect = MemoryError()        try:            with pytest.raises(WorkerTerminate):                asynloop(*x.args)        finally:            state.should_terminate = None    def test_should_stop(self):        x = X(self.app)        state.should_stop = 303        try:            with pytest.raises(WorkerShutdown):                asynloop(*x.args)        finally:            state.should_stop = None    def test_updates_qos(self):        x = X(self.app)        x.qos.prev = 3        x.qos.value = 3        x.hub.on_tick.add(x.closer(mod=2))        x.hub.timer._queue = [1]        asynloop(*x.args)        x.qos.update.assert_not_called()        x = X(self.app)        x.qos.prev = 1        x.qos.value = 6        x.hub.on_tick.add(x.closer(mod=2))        asynloop(*x.args)        x.qos.update.assert_called_with()        x.hub.fire_timers.assert_called_with(propagate=(socket.error,))    def test_poll_empty(self):        x = X(self.app)        x.hub.readers = {6: Mock()}        x.hub.timer._queue = [1]        x.close_then_error(x.hub.poller.poll)        x.hub.fire_timers.return_value = 33.37        poller = x.hub.poller        poller.poll.return_value = []        with pytest.raises(socket.error):            asynloop(*x.args)        poller.poll.assert_called_with(33.37)    def test_poll_readable(self):        x = X(self.app)        reader = Mock(name='reader')        x.hub.add_reader(6, reader, 6)        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))        poller = x.hub.poller        poller.poll.return_value = [(6, READ)]        with pytest.raises(socket.error):            asynloop(*x.args)        reader.assert_called_with(6)        poller.poll.assert_called()    def test_poll_readable_raises_Empty(self):        x = X(self.app)        reader = Mock(name='reader')        x.hub.add_reader(6, reader, 6)        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))        poller = x.hub.poller        poller.poll.return_value = [(6, READ)]        reader.side_effect = Empty()        with pytest.raises(socket.error):            asynloop(*x.args)        reader.assert_called_with(6)        poller.poll.assert_called()    def test_poll_writable(self):        x = X(self.app)        writer = Mock(name='writer')        x.hub.add_writer(6, writer, 6)        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))        poller = x.hub.poller        poller.poll.return_value = [(6, WRITE)]        with pytest.raises(socket.error):            asynloop(*x.args)        writer.assert_called_with(6)        poller.poll.assert_called()    def test_poll_writable_none_registered(self):        x = X(self.app)        writer = Mock(name='writer')        x.hub.add_writer(6, writer, 6)        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))        poller = x.hub.poller        poller.poll.return_value = [(7, WRITE)]        with pytest.raises(socket.error):            asynloop(*x.args)        poller.poll.assert_called()    def test_poll_unknown_event(self):        x = X(self.app)        writer = Mock(name='reader')        x.hub.add_writer(6, writer, 6)        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))        poller = x.hub.poller        poller.poll.return_value = [(6, 0)]        with pytest.raises(socket.error):            asynloop(*x.args)        poller.poll.assert_called()    def test_poll_keep_draining_disabled(self):        x = X(self.app)        x.hub.writers = {6: Mock()}        poll = x.hub.poller.poll        def se(*args, **kwargs):            poll.side_effect = socket.error()        poll.side_effect = se        poller = x.hub.poller        poll.return_value = [(6, 0)]        with pytest.raises(socket.error):            asynloop(*x.args)        poller.poll.assert_called()    def test_poll_err_writable(self):        x = X(self.app)        writer = Mock(name='writer')        x.hub.add_writer(6, writer, 6, 48)        x.hub.on_tick.add(x.close_then_error(Mock(), 2))        poller = x.hub.poller        poller.poll.return_value = [(6, ERR)]        with pytest.raises(socket.error):            asynloop(*x.args)        writer.assert_called_with(6, 48)        poller.poll.assert_called()    def test_poll_write_generator(self):        x = X(self.app)        x.hub.remove = Mock(name='hub.remove()')        def Gen():            yield 1            yield 2        gen = Gen()        x.hub.add_writer(6, gen)        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))        x.hub.poller.poll.return_value = [(6, WRITE)]        with pytest.raises(socket.error):            asynloop(*x.args)        assert gen.gi_frame.f_lasti != -1        x.hub.remove.assert_not_called()    def test_poll_write_generator_stopped(self):        x = X(self.app)        def Gen():            raise StopIteration()            yield        gen = Gen()        x.hub.add_writer(6, gen)        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))        x.hub.poller.poll.return_value = [(6, WRITE)]        x.hub.remove = Mock(name='hub.remove()')        with pytest.raises(socket.error):            asynloop(*x.args)        assert gen.gi_frame is None    def test_poll_write_generator_raises(self):        x = X(self.app)        def Gen():            raise ValueError('foo')            yield        gen = Gen()        x.hub.add_writer(6, gen)        x.hub.remove = Mock(name='hub.remove()')        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))        x.hub.poller.poll.return_value = [(6, WRITE)]        with pytest.raises(ValueError):            asynloop(*x.args)        assert gen.gi_frame is None        x.hub.remove.assert_called_with(6)    def test_poll_err_readable(self):        x = X(self.app)        reader = Mock(name='reader')        x.hub.add_reader(6, reader, 6, 24)        x.hub.on_tick.add(x.close_then_error(Mock(), 2))        poller = x.hub.poller        poller.poll.return_value = [(6, ERR)]        with pytest.raises(socket.error):            asynloop(*x.args)        reader.assert_called_with(6, 24)        poller.poll.assert_called()    def test_poll_raises_ValueError(self):        x = X(self.app)        x.hub.readers = {6: Mock()}        poller = x.hub.poller        x.close_then_error(poller.poll, exc=ValueError)        asynloop(*x.args)        poller.poll.assert_called()class test_synloop:    def test_timeout_ignored(self):        x = X(self.app)        x.timeout_then_error(x.connection.drain_events)        with pytest.raises(socket.error):            synloop(*x.args)        assert x.connection.drain_events.call_count == 2    def test_updates_qos_when_changed(self):        x = X(self.app)        x.qos.prev = 2        x.qos.value = 2        x.timeout_then_error(x.connection.drain_events)        with pytest.raises(socket.error):            synloop(*x.args)        x.qos.update.assert_not_called()        x.qos.value = 4        x.timeout_then_error(x.connection.drain_events)        with pytest.raises(socket.error):            synloop(*x.args)        x.qos.update.assert_called_with()    def test_ignores_socket_errors_when_closed(self):        x = X(self.app)        x.close_then_error(x.connection.drain_events)        assert synloop(*x.args) is Noneclass test_quick_drain:    def setup(self):        self.connection = Mock(name='connection')    def test_drain(self):        _quick_drain(self.connection, timeout=33.3)        self.connection.drain_events.assert_called_with(timeout=33.3)    def test_drain_error(self):        exc = KeyError()        exc.errno = 313        self.connection.drain_events.side_effect = exc        with pytest.raises(KeyError):            _quick_drain(self.connection, timeout=33.3)    def test_drain_error_EAGAIN(self):        exc = KeyError()        exc.errno = errno.EAGAIN        self.connection.drain_events.side_effect = exc        _quick_drain(self.connection, timeout=33.3)
 |