| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602 | from __future__ import absolute_import, unicode_literalsimport errnoimport socketfrom collections import dequeimport pytestfrom billiard.exceptions import RestartFreqExceededfrom case import ContextMock, Mock, call, patch, skipfrom celery.utils.collections import LimitedSetfrom celery.worker.consumer.agent import Agentfrom celery.worker.consumer.consumer import (CLOSE, TERMINATE, Consumer,                                             dump_body)from celery.worker.consumer.gossip import Gossipfrom celery.worker.consumer.heart import Heartfrom celery.worker.consumer.mingle import Minglefrom celery.worker.consumer.tasks import Tasksclass test_Consumer:    def get_consumer(self, no_hub=False, **kwargs):        consumer = Consumer(            on_task_request=Mock(),            init_callback=Mock(),            pool=Mock(),            app=self.app,            timer=Mock(),            controller=Mock(),            hub=None if no_hub else Mock(),            **kwargs        )        consumer.blueprint = Mock(name='blueprint')        consumer._restart_state = Mock(name='_restart_state')        consumer.connection = _amqp_connection()        consumer.connection_errors = (socket.error, OSError,)        consumer.conninfo = consumer.connection        return consumer    def test_repr(self):        assert repr(self.get_consumer())    def test_taskbuckets_defaultdict(self):        c = self.get_consumer()        assert c.task_buckets['fooxasdwx.wewe'] is None    @skip.if_python3(reason='buffer type not available')    def test_dump_body_buffer(self):        msg = Mock()        msg.body = 'str'        assert dump_body(msg, buffer(msg.body))  # noqa: F821    def test_sets_heartbeat(self):        c = self.get_consumer(amqheartbeat=10)        assert c.amqheartbeat == 10        self.app.conf.broker_heartbeat = 20        c = self.get_consumer(amqheartbeat=None)        assert c.amqheartbeat == 20    def test_gevent_bug_disables_connection_timeout(self):        with patch('celery.worker.consumer.consumer._detect_environment') as d:            d.return_value = 'gevent'            self.app.conf.broker_connection_timeout = 33.33            self.get_consumer()            assert self.app.conf.broker_connection_timeout is None    def test_limit_moved_to_pool(self):        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:            c = self.get_consumer()            c.on_task_request = Mock(name='on_task_request')            request = Mock(name='request')            c._limit_move_to_pool(request)            reserv.assert_called_with(request)            c.on_task_request.assert_called_with(request)    def test_update_prefetch_count(self):        c = self.get_consumer()        c._update_qos_eventually = Mock(name='update_qos')        c.initial_prefetch_count = None        c.pool.num_processes = None        c.prefetch_multiplier = 10        assert c._update_prefetch_count(1) is None        c.initial_prefetch_count = 10        c.pool.num_processes = 10        c._update_prefetch_count(8)        c._update_qos_eventually.assert_called_with(8)        assert c.initial_prefetch_count == 10 * 10    def test_flush_events(self):        c = self.get_consumer()        c.event_dispatcher = None        c._flush_events()        c.event_dispatcher = Mock(name='evd')        c._flush_events()        c.event_dispatcher.flush.assert_called_with()    def test_on_send_event_buffered(self):        c = self.get_consumer()        c.hub = None        c.on_send_event_buffered()        c.hub = Mock(name='hub')        c.on_send_event_buffered()        c.hub._ready.add.assert_called_with(c._flush_events)    def test_limit_task(self):        c = self.get_consumer()        c.timer = Mock()        bucket = Mock()        request = Mock()        bucket.can_consume.return_value = True        bucket.contents = deque()        c._limit_task(request, bucket, 3)        bucket.can_consume.assert_called_with(3)        bucket.expected_time.assert_called_with(3)        c.timer.call_after.assert_called_with(            bucket.expected_time(), c._on_bucket_wakeup, (bucket, 3),            priority=c._limit_order,        )        bucket.can_consume.return_value = False        bucket.expected_time.return_value = 3.33        limit_order = c._limit_order        c._limit_task(request, bucket, 4)        assert c._limit_order == limit_order + 1        bucket.can_consume.assert_called_with(4)        c.timer.call_after.assert_called_with(            3.33, c._on_bucket_wakeup, (bucket, 4),            priority=c._limit_order,        )        bucket.expected_time.assert_called_with(4)    def test_start_blueprint_raises_EMFILE(self):        c = self.get_consumer()        exc = c.blueprint.start.side_effect = OSError()        exc.errno = errno.EMFILE        with pytest.raises(OSError):            c.start()    def test_max_restarts_exceeded(self):        c = self.get_consumer()        def se(*args, **kwargs):            c.blueprint.state = CLOSE            raise RestartFreqExceeded()        c._restart_state.step.side_effect = se        c.blueprint.start.side_effect = socket.error()        with patch('celery.worker.consumer.consumer.sleep') as sleep:            c.start()            sleep.assert_called_with(1)    def test_do_not_restart_when_closed(self):        c = self.get_consumer()        c.blueprint.state = None        def bp_start(*args, **kwargs):            c.blueprint.state = CLOSE        c.blueprint.start.side_effect = bp_start        with patch('celery.worker.consumer.consumer.sleep'):            c.start()        c.blueprint.start.assert_called_once_with(c)    def test_do_not_restart_when_terminated(self):        c = self.get_consumer()        c.blueprint.state = None        def bp_start(*args, **kwargs):            c.blueprint.state = TERMINATE        c.blueprint.start.side_effect = bp_start        with patch('celery.worker.consumer.consumer.sleep'):            c.start()        c.blueprint.start.assert_called_once_with(c)    def test_no_retry_raises_error(self):        self.app.conf.broker_connection_retry = False        c = self.get_consumer()        c.blueprint.start.side_effect = socket.error()        with pytest.raises(socket.error):            c.start()    def _closer(self, c):        def se(*args, **kwargs):            c.blueprint.state = CLOSE        return se    def test_collects_at_restart(self):        c = self.get_consumer()        c.connection.collect.side_effect = MemoryError()        c.blueprint.start.side_effect = socket.error()        c.blueprint.restart.side_effect = self._closer(c)        c.start()        c.connection.collect.assert_called_with()    def test_register_with_event_loop(self):        c = self.get_consumer()        c.register_with_event_loop(Mock(name='loop'))    def test_on_close_clears_semaphore_timer_and_reqs(self):        with patch('celery.worker.consumer.consumer.reserved_requests') as res:            c = self.get_consumer()            c.on_close()            c.controller.semaphore.clear.assert_called_with()            c.timer.clear.assert_called_with()            res.clear.assert_called_with()            c.pool.flush.assert_called_with()            c.controller = None            c.timer = None            c.pool = None            c.on_close()    def test_connect_error_handler(self):        self.app._connection = _amqp_connection()        conn = self.app._connection.return_value        c = self.get_consumer()        assert c.connect()        conn.ensure_connection.assert_called()        errback = conn.ensure_connection.call_args[0][0]        errback(Mock(), 0)class test_Heart:    def test_start(self):        c = Mock()        c.timer = Mock()        c.event_dispatcher = Mock()        with patch('celery.worker.heartbeat.Heart') as hcls:            h = Heart(c)            assert h.enabled            assert h.heartbeat_interval is None            assert c.heart is None            h.start(c)            assert c.heart            hcls.assert_called_with(c.timer, c.event_dispatcher,                                    h.heartbeat_interval)            c.heart.start.assert_called_with()    def test_start_heartbeat_interval(self):        c = Mock()        c.timer = Mock()        c.event_dispatcher = Mock()        with patch('celery.worker.heartbeat.Heart') as hcls:            h = Heart(c, False, 20)            assert h.enabled            assert h.heartbeat_interval == 20            assert c.heart is None            h.start(c)            assert c.heart            hcls.assert_called_with(c.timer, c.event_dispatcher,                                    h.heartbeat_interval)            c.heart.start.assert_called_with()class test_Tasks:    def test_stop(self):        c = Mock()        tasks = Tasks(c)        assert c.task_consumer is None        assert c.qos is None        c.task_consumer = Mock()        tasks.stop(c)    def test_stop_already_stopped(self):        c = Mock()        tasks = Tasks(c)        tasks.stop(c)class test_Agent:    def test_start(self):        c = Mock()        agent = Agent(c)        agent.instantiate = Mock()        agent.agent_cls = 'foo:Agent'        assert agent.create(c) is not None        agent.instantiate.assert_called_with(agent.agent_cls, c.connection)class test_Mingle:    def test_start_no_replies(self):        c = Mock()        c.app.connection_for_read = _amqp_connection()        mingle = Mingle(c)        I = c.app.control.inspect.return_value = Mock()        I.hello.return_value = {}        mingle.start(c)    def test_start(self):        c = Mock()        c.app.connection_for_read = _amqp_connection()        mingle = Mingle(c)        assert mingle.enabled        Aig = LimitedSet()        Big = LimitedSet()        Aig.add('Aig-1')        Aig.add('Aig-2')        Big.add('Big-1')        I = c.app.control.inspect.return_value = Mock()        I.hello.return_value = {            'A@example.com': {                'clock': 312,                'revoked': Aig._data,            },            'B@example.com': {                'clock': 29,                'revoked': Big._data,            },            'C@example.com': {                'error': 'unknown method',            },        }        our_revoked = c.controller.state.revoked = LimitedSet()        mingle.start(c)        I.hello.assert_called_with(c.hostname, our_revoked._data)        c.app.clock.adjust.assert_has_calls([            call(312), call(29),        ], any_order=True)        assert 'Aig-1' in our_revoked        assert 'Aig-2' in our_revoked        assert 'Big-1' in our_revokeddef _amqp_connection():    connection = ContextMock(name='Connection')    connection.return_value = ContextMock(name='connection')    connection.return_value.transport.driver_type = 'amqp'    return connectionclass test_Gossip:    def test_init(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        assert g.enabled        assert c.gossip is g    def test_callbacks(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        on_node_join = Mock(name='on_node_join')        on_node_join2 = Mock(name='on_node_join2')        on_node_leave = Mock(name='on_node_leave')        on_node_lost = Mock(name='on.node_lost')        g.on.node_join.add(on_node_join)        g.on.node_join.add(on_node_join2)        g.on.node_leave.add(on_node_leave)        g.on.node_lost.add(on_node_lost)        worker = Mock(name='worker')        g.on_node_join(worker)        on_node_join.assert_called_with(worker)        on_node_join2.assert_called_with(worker)        g.on_node_leave(worker)        on_node_leave.assert_called_with(worker)        g.on_node_lost(worker)        on_node_lost.assert_called_with(worker)    def test_election(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        g.start(c)        g.election('id', 'topic', 'action')        assert g.consensus_replies['id'] == []        g.dispatcher.send.assert_called_with(            'worker-elect', id='id', topic='topic', cver=1, action='action',        )    def test_call_task(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        g.start(c)        signature = g.app.signature = Mock(name='app.signature')        task = Mock()        g.call_task(task)        signature.assert_called_with(task)        signature.return_value.apply_async.assert_called_with()        signature.return_value.apply_async.side_effect = MemoryError()        with patch('celery.worker.consumer.gossip.logger') as logger:            g.call_task(task)            logger.exception.assert_called()    def Event(self, id='id', clock=312,              hostname='foo@example.com', pid=4312,              topic='topic', action='action', cver=1):        return {            'id': id,            'clock': clock,            'hostname': hostname,            'pid': pid,            'topic': topic,            'action': action,            'cver': cver,        }    def test_on_elect(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        g.start(c)        event = self.Event('id1')        g.on_elect(event)        in_heap = g.consensus_requests['id1']        assert in_heap        g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')        event.pop('clock')        with patch('celery.worker.consumer.gossip.logger') as logger:            g.on_elect(event)            logger.exception.assert_called()    def Consumer(self, hostname='foo@x.com', pid=4312):        c = Mock()        c.app.connection = _amqp_connection()        c.hostname = hostname        c.pid = pid        return c    def setup_election(self, g, c):        g.start(c)        g.clock = self.app.clock        assert 'idx' not in g.consensus_replies        assert g.on_elect_ack({'id': 'idx'}) is None        g.state.alive_workers.return_value = [            'foo@x.com', 'bar@x.com', 'baz@x.com',        ]        g.consensus_replies['id1'] = []        g.consensus_requests['id1'] = []        e1 = self.Event('id1', 1, 'foo@x.com')        e2 = self.Event('id1', 2, 'bar@x.com')        e3 = self.Event('id1', 3, 'baz@x.com')        g.on_elect(e1)        g.on_elect(e2)        g.on_elect(e3)        assert len(g.consensus_requests['id1']) == 3        with patch('celery.worker.consumer.gossip.info'):            g.on_elect_ack(e1)            assert len(g.consensus_replies['id1']) == 1            g.on_elect_ack(e2)            assert len(g.consensus_replies['id1']) == 2            g.on_elect_ack(e3)            with pytest.raises(KeyError):                g.consensus_replies['id1']    def test_on_elect_ack_win(self):        c = self.Consumer(hostname='foo@x.com')  # I will win        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        handler = g.election_handlers['topic'] = Mock()        self.setup_election(g, c)        handler.assert_called_with('action')    def test_on_elect_ack_lose(self):        c = self.Consumer(hostname='bar@x.com')  # I will lose        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        handler = g.election_handlers['topic'] = Mock()        self.setup_election(g, c)        handler.assert_not_called()    def test_on_elect_ack_win_but_no_action(self):        c = self.Consumer(hostname='foo@x.com')  # I will win        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        g.election_handlers = {}        with patch('celery.worker.consumer.gossip.logger') as logger:            self.setup_election(g, c)            logger.exception.assert_called()    def test_on_node_join(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        with patch('celery.worker.consumer.gossip.debug') as debug:            g.on_node_join(c)            debug.assert_called_with('%s joined the party', 'foo@x.com')    def test_on_node_leave(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        with patch('celery.worker.consumer.gossip.debug') as debug:            g.on_node_leave(c)            debug.assert_called_with('%s left', 'foo@x.com')    def test_on_node_lost(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        with patch('celery.worker.consumer.gossip.info') as info:            g.on_node_lost(c)            info.assert_called_with('missed heartbeat from %s', 'foo@x.com')    def test_register_timer(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        g.register_timer()        c.timer.call_repeatedly.assert_called_with(g.interval, g.periodic)        tref = g._tref        g.register_timer()        tref.cancel.assert_called_with()    def test_periodic(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        g.on_node_lost = Mock()        state = g.state = Mock()        worker = Mock()        state.workers = {'foo': worker}        worker.alive = True        worker.hostname = 'foo'        g.periodic()        worker.alive = False        g.periodic()        g.on_node_lost.assert_called_with(worker)        with pytest.raises(KeyError):            state.workers['foo']    def test_on_message__task(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        assert g.enabled        message = Mock(name='message')        message.delivery_info = {'routing_key': 'task.failed'}        g.on_message(Mock(name='prepare'), message)    def test_on_message(self):        c = self.Consumer()        c.app.connection_for_read = _amqp_connection()        g = Gossip(c)        assert g.enabled        prepare = Mock()        prepare.return_value = 'worker-online', {}        c.app.events.State.assert_called_with(            on_node_join=g.on_node_join,            on_node_leave=g.on_node_leave,            max_tasks_in_memory=1,        )        g.update_state = Mock()        worker = Mock()        g.on_node_join = Mock()        g.on_node_leave = Mock()        g.update_state.return_value = worker, 1        message = Mock()        message.delivery_info = {'routing_key': 'worker-online'}        message.headers = {'hostname': 'other'}        handler = g.event_handlers['worker-online'] = Mock()        g.on_message(prepare, message)        handler.assert_called_with(message.payload)        g.event_handlers = {}        g.on_message(prepare, message)        message.delivery_info = {'routing_key': 'worker-offline'}        prepare.return_value = 'worker-offline', {}        g.on_message(prepare, message)        message.delivery_info = {'routing_key': 'worker-baz'}        prepare.return_value = 'worker-baz', {}        g.update_state.return_value = worker, 0        g.on_message(prepare, message)        message.headers = {'hostname': g.hostname}        g.on_message(prepare, message)        g.clock.forward.assert_called_with()
 |