123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565 |
- import errno
- import pytest
- import socket
- from collections import deque
- from case import ContextMock, Mock, call, patch
- from billiard.exceptions import RestartFreqExceeded
- from celery.worker.consumer.agent import Agent
- from celery.worker.consumer.consumer import CLOSE, Consumer
- from celery.worker.consumer.gossip import Gossip
- from celery.worker.consumer.heart import Heart
- from celery.worker.consumer.mingle import Mingle
- from celery.worker.consumer.tasks import Tasks
- from celery.utils.collections import LimitedSet
class test_Consumer:
    """Unit tests for ``celery.worker.consumer.consumer.Consumer``.

    ``self.app`` is not defined in this class; presumably it is injected by
    a fixture elsewhere in the test suite -- TODO confirm.
    """

    def get_consumer(self, no_hub=False, **kwargs):
        """Build a ``Consumer`` whose collaborators are mocks.

        :param no_hub: pass ``hub=None`` instead of a mock hub.
        :param kwargs: extra keyword arguments forwarded to ``Consumer``.
        :returns: the consumer, with ``blueprint``, ``_restart_state``,
            ``connection`` and ``conninfo`` replaced by mocks and
            ``connection_errors`` set to ``(socket.error, OSError)``.
        """
        consumer = Consumer(
            on_task_request=Mock(),
            init_callback=Mock(),
            pool=Mock(),
            app=self.app,
            timer=Mock(),
            controller=Mock(),
            hub=None if no_hub else Mock(),
            **kwargs
        )
        consumer.blueprint = Mock(name='blueprint')
        consumer._restart_state = Mock(name='_restart_state')
        consumer.connection = _amqp_connection()
        consumer.connection_errors = (socket.error, OSError,)
        # conninfo shares the very same mocked connection factory.
        consumer.conninfo = consumer.connection
        return consumer

    def test_repr(self):
        """``repr()`` of a consumer is a non-empty string."""
        assert repr(self.get_consumer())

    def test_taskbuckets_defaultdict(self):
        """Looking up an unknown task name in ``task_buckets`` gives None."""
        c = self.get_consumer()
        assert c.task_buckets['fooxasdwx.wewe'] is None

    def test_sets_heartbeat(self):
        """An explicit ``amqheartbeat`` is kept; ``None`` falls back to the
        ``broker_heartbeat`` setting."""
        c = self.get_consumer(amqheartbeat=10)
        assert c.amqheartbeat == 10
        self.app.conf.broker_heartbeat = 20
        c = self.get_consumer(amqheartbeat=None)
        assert c.amqheartbeat == 20

    def test_gevent_bug_disables_connection_timeout(self):
        """When the environment is detected as gevent, creating a consumer
        resets ``broker_connection_timeout`` to ``None``."""
        with patch('celery.worker.consumer.consumer._detect_environment') as d:
            d.return_value = 'gevent'
            self.app.conf.broker_connection_timeout = 33.33
            self.get_consumer()
            assert self.app.conf.broker_connection_timeout is None

    def test_limit_moved_to_pool(self):
        """``_limit_move_to_pool`` passes the request to ``task_reserved``
        and then to ``on_task_request``."""
        with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
            c = self.get_consumer()
            c.on_task_request = Mock(name='on_task_request')
            request = Mock(name='request')
            c._limit_move_to_pool(request)
            reserv.assert_called_with(request)
            c.on_task_request.assert_called_with(request)

    def test_update_prefetch_count(self):
        """``_update_prefetch_count`` returns None when no initial prefetch
        count is set; otherwise it forwards the value to
        ``_update_qos_eventually`` and recomputes ``initial_prefetch_count``.
        """
        c = self.get_consumer()
        c._update_qos_eventually = Mock(name='update_qos')
        c.initial_prefetch_count = None
        c.pool.num_processes = None
        c.prefetch_multiplier = 10
        assert c._update_prefetch_count(1) is None
        c.initial_prefetch_count = 10
        c.pool.num_processes = 10
        c._update_prefetch_count(8)
        c._update_qos_eventually.assert_called_with(8)
        # Presumably num_processes * prefetch_multiplier -- both are 10 here.
        assert c.initial_prefetch_count == 10 * 10

    def test_flush_events(self):
        """``_flush_events`` tolerates a missing dispatcher and otherwise
        calls ``flush()`` on it."""
        c = self.get_consumer()
        c.event_dispatcher = None
        c._flush_events()
        c.event_dispatcher = Mock(name='evd')
        c._flush_events()
        c.event_dispatcher.flush.assert_called_with()

    def test_on_send_event_buffered(self):
        """Without a hub the call does not raise; with a hub,
        ``_flush_events`` is added to ``hub._ready``."""
        c = self.get_consumer()
        c.hub = None
        c.on_send_event_buffered()
        c.hub = Mock(name='hub')
        c.on_send_event_buffered()
        c.hub._ready.add.assert_called_with(c._flush_events)

    def test_limit_task(self):
        """``_limit_task`` consults the bucket and schedules a wakeup.

        Both phases assert ``timer.call_after`` is scheduled with
        ``_on_bucket_wakeup``, ``(bucket, tokens)`` and ``priority`` equal
        to ``_limit_order``.  The second phase (``can_consume`` is False)
        also checks that ``_limit_order`` is incremented by one.
        """
        c = self.get_consumer()
        c.timer = Mock()
        bucket = Mock()
        request = Mock()
        bucket.can_consume.return_value = True
        bucket.contents = deque()
        c._limit_task(request, bucket, 3)
        bucket.can_consume.assert_called_with(3)
        bucket.expected_time.assert_called_with(3)
        # NOTE: bucket.expected_time() below records an extra no-argument
        # call and returns the same child mock _limit_task received.  The
        # final expected_time assertion still holds because _limit_task is
        # called again before it.
        c.timer.call_after.assert_called_with(
            bucket.expected_time(), c._on_bucket_wakeup, (bucket, 3),
            priority=c._limit_order,
        )
        bucket.can_consume.return_value = False
        bucket.expected_time.return_value = 3.33
        limit_order = c._limit_order
        c._limit_task(request, bucket, 4)
        assert c._limit_order == limit_order + 1
        bucket.can_consume.assert_called_with(4)
        c.timer.call_after.assert_called_with(
            3.33, c._on_bucket_wakeup, (bucket, 4),
            priority=c._limit_order,
        )
        bucket.expected_time.assert_called_with(4)

    def test_start_blueprint_raises_EMFILE(self):
        """An ``OSError`` with ``errno.EMFILE`` from ``blueprint.start``
        propagates out of ``start()``."""
        c = self.get_consumer()
        exc = c.blueprint.start.side_effect = OSError()
        exc.errno = errno.EMFILE
        with pytest.raises(OSError):
            c.start()

    def test_max_restarts_exceeded(self):
        """A ``RestartFreqExceeded`` from the restart-state step after a
        connection error results in ``sleep(1)``."""
        c = self.get_consumer()

        def se(*args, **kwargs):
            # Mark the blueprint closed (presumably so start() stops
            # looping), then signal too many restarts.
            c.blueprint.state = CLOSE
            raise RestartFreqExceeded()
        c._restart_state.step.side_effect = se
        c.blueprint.start.side_effect = socket.error()
        with patch('celery.worker.consumer.consumer.sleep') as sleep:
            c.start()
            sleep.assert_called_with(1)

    def test_no_retry_raises_error(self):
        """With ``broker_connection_retry`` disabled, a socket error from
        ``blueprint.start`` propagates."""
        self.app.conf.broker_connection_retry = False
        c = self.get_consumer()
        c.blueprint.start.side_effect = socket.error()
        with pytest.raises(socket.error):
            c.start()

    def _closer(self, c):
        """Return a side-effect callable that sets ``c.blueprint.state`` to
        ``CLOSE``."""
        def se(*args, **kwargs):
            c.blueprint.state = CLOSE
        return se

    def test_collects_at_restart(self):
        """After a socket error, ``connection.collect()`` is called.  The
        ``MemoryError`` it raises does not escape ``start()``."""
        c = self.get_consumer()
        c.connection.collect.side_effect = MemoryError()
        c.blueprint.start.side_effect = socket.error()
        c.blueprint.restart.side_effect = self._closer(c)
        c.start()
        c.connection.collect.assert_called_with()

    def test_register_with_event_loop(self):
        """Smoke test: registering with a mock loop does not raise."""
        c = self.get_consumer()
        c.register_with_event_loop(Mock(name='loop'))

    def test_on_close_clears_semaphore_timer_and_reqs(self):
        """``on_close`` clears the controller semaphore, the timer and the
        reserved requests, and flushes the pool.  It still succeeds when
        those collaborators are ``None``."""
        with patch('celery.worker.consumer.consumer.reserved_requests') as res:
            c = self.get_consumer()
            c.on_close()
            c.controller.semaphore.clear.assert_called_with()
            c.timer.clear.assert_called_with()
            res.clear.assert_called_with()
            c.pool.flush.assert_called_with()
            c.controller = None
            c.timer = None
            c.pool = None
            c.on_close()

    def test_connect_error_handler(self):
        """``connect()`` returns a truthy connection and passes an errback
        as the first positional argument to ``ensure_connection``."""
        self.app._connection = _amqp_connection()
        conn = self.app._connection.return_value
        c = self.get_consumer()
        assert c.connect()
        conn.ensure_connection.assert_called()
        errback = conn.ensure_connection.call_args[0][0]
        # Two arguments -- presumably (exception, interval); TODO confirm.
        errback(Mock(), 0)
class test_Heart:
    """Tests for the ``Heart`` consumer bootstep."""

    def _check_start(self, expected_interval, *heart_args):
        """Create a ``Heart`` step on a mock consumer, start it and verify it.

        :param expected_interval: value ``heartbeat_interval`` must have;
            ``None`` is checked by identity.
        :param heart_args: extra positional arguments for ``Heart(c, ...)``.

        The helper checks three things:

        * the step is enabled;
        * ``c.heart`` is ``None`` until ``start`` runs;
        * ``start`` builds the patched ``celery.worker.heartbeat.Heart``
          from the consumer's timer, the event dispatcher and the step's
          interval, and calls ``start()`` on ``c.heart``.
        """
        c = Mock()
        c.timer = Mock()
        c.event_dispatcher = Mock()
        with patch('celery.worker.heartbeat.Heart') as hcls:
            h = Heart(c, *heart_args)
            assert h.enabled
            if expected_interval is None:
                assert h.heartbeat_interval is None
            else:
                assert h.heartbeat_interval == expected_interval
            assert c.heart is None
            h.start(c)
            assert c.heart
            hcls.assert_called_with(c.timer, c.event_dispatcher,
                                    h.heartbeat_interval)
            c.heart.start.assert_called_with()

    def test_start(self):
        """Default construction has no explicit heartbeat interval."""
        self._check_start(None)

    def test_start_heartbeat_interval(self):
        """An explicit interval (third positional argument) is kept."""
        self._check_start(20, False, 20)
class test_Tasks:
    """Tests for stopping the ``Tasks`` consumer bootstep."""

    def test_stop(self):
        """Stopping after a task consumer has been attached does not raise."""
        consumer = Mock()
        step = Tasks(consumer)
        # Right after construction both attributes read as None.
        assert consumer.task_consumer is None
        assert consumer.qos is None

        consumer.task_consumer = Mock()
        step.stop(consumer)

    def test_stop_already_stopped(self):
        """Stopping with no task consumer attached does not raise."""
        consumer = Mock()
        step = Tasks(consumer)
        step.stop(consumer)
class test_Agent:
    """Tests for the ``Agent`` consumer bootstep."""

    def test_start(self):
        """``create`` instantiates ``agent_cls`` with the consumer's
        connection and returns a non-None result."""
        consumer = Mock()
        step = Agent(consumer)
        step.instantiate = Mock()
        step.agent_cls = 'foo:Agent'
        created = step.create(consumer)
        assert created is not None
        step.instantiate.assert_called_with(step.agent_cls,
                                            consumer.connection)
class test_Mingle:
    """Tests for the ``Mingle`` consumer bootstep."""

    def test_start_no_replies(self):
        """With an empty ``hello`` reply, ``start`` still sends the hello
        request but applies no remote clock value."""
        c = Mock()
        c.app.connection_for_read = _amqp_connection()
        mingle = Mingle(c)
        inspector = c.app.control.inspect.return_value = Mock()
        inspector.hello.return_value = {}
        mingle.start(c)
        # This is the same request test_start asserts.  Here
        # c.controller.state.revoked is an auto-created Mock attribute, so
        # repeated access yields the identical child object.
        inspector.hello.assert_called_with(
            c.hostname, c.controller.state.revoked._data)
        # Nothing came back, so there is no clock to adjust to.
        c.app.clock.adjust.assert_not_called()

    def test_start(self):
        """Replies from other nodes are merged.

        * Each reply's ``clock`` is passed to ``app.clock.adjust``.
        * Each reply's revoked ids are added to our revoked set.
        * A reply carrying only an ``error`` key does not break start-up.
        """
        c = Mock()
        c.app.connection_for_read = _amqp_connection()
        mingle = Mingle(c)
        assert mingle.enabled
        Aig = LimitedSet()
        Big = LimitedSet()
        Aig.add('Aig-1')
        Aig.add('Aig-2')
        Big.add('Big-1')
        inspector = c.app.control.inspect.return_value = Mock()
        inspector.hello.return_value = {
            'A@example.com': {
                'clock': 312,
                'revoked': Aig._data,
            },
            'B@example.com': {
                'clock': 29,
                'revoked': Big._data,
            },
            'C@example.com': {
                'error': 'unknown method',
            },
        }
        our_revoked = c.controller.state.revoked = LimitedSet()
        mingle.start(c)
        inspector.hello.assert_called_with(c.hostname, our_revoked._data)
        c.app.clock.adjust.assert_has_calls([
            call(312), call(29),
        ], any_order=True)
        assert 'Aig-1' in our_revoked
        assert 'Aig-2' in our_revoked
        assert 'Big-1' in our_revoked
def _amqp_connection():
    """Return a mocked connection factory for tests.

    The factory is a ``ContextMock``.  Calling it returns another
    ``ContextMock`` whose ``transport.driver_type`` is ``'amqp'``.
    """
    instance = ContextMock(name='connection')
    instance.transport.driver_type = 'amqp'
    factory = ContextMock(name='Connection')
    factory.return_value = instance
    return factory
class test_Gossip:
    """Tests for the ``Gossip`` consumer bootstep.

    Every test builds its mock consumer with :meth:`Consumer`.  That helper
    installs mocked ``app.connection`` and ``app.connection_for_read``
    factories before ``Gossip(c)`` is constructed.
    """

    def test_init(self):
        """The step is enabled and registers itself as ``c.gossip``."""
        c = self.Consumer()
        g = Gossip(c)
        assert g.enabled
        assert c.gossip is g

    def test_callbacks(self):
        """Callbacks added to ``g.on.*`` receive the worker for each event."""
        c = self.Consumer()
        g = Gossip(c)
        on_node_join = Mock(name='on_node_join')
        on_node_join2 = Mock(name='on_node_join2')
        on_node_leave = Mock(name='on_node_leave')
        on_node_lost = Mock(name='on.node_lost')
        g.on.node_join.add(on_node_join)
        g.on.node_join.add(on_node_join2)
        g.on.node_leave.add(on_node_leave)
        g.on.node_lost.add(on_node_lost)
        worker = Mock(name='worker')
        g.on_node_join(worker)
        on_node_join.assert_called_with(worker)
        on_node_join2.assert_called_with(worker)
        g.on_node_leave(worker)
        on_node_leave.assert_called_with(worker)
        g.on_node_lost(worker)
        on_node_lost.assert_called_with(worker)

    def test_election(self):
        """``election`` opens an empty reply list and sends a
        ``worker-elect`` event through the dispatcher."""
        c = self.Consumer()
        g = Gossip(c)
        g.start(c)
        g.election('id', 'topic', 'action')
        assert g.consensus_replies['id'] == []
        g.dispatcher.send.assert_called_with(
            'worker-elect', id='id', topic='topic', cver=1, action='action',
        )

    def test_call_task(self):
        """``call_task`` applies the task's signature.  A failure while
        applying it is logged with ``logger.exception``, not raised."""
        c = self.Consumer()
        g = Gossip(c)
        g.start(c)
        signature = g.app.signature = Mock(name='app.signature')
        task = Mock()
        g.call_task(task)
        signature.assert_called_with(task)
        signature.return_value.apply_async.assert_called_with()
        signature.return_value.apply_async.side_effect = MemoryError()
        with patch('celery.worker.consumer.gossip.logger') as logger:
            g.call_task(task)
            logger.exception.assert_called()

    def Event(self, id='id', clock=312,
              hostname='foo@example.com', pid=4312,
              topic='topic', action='action', cver=1):
        """Return an election event payload suitable for ``on_elect``."""
        return {
            'id': id,
            'clock': clock,
            'hostname': hostname,
            'pid': pid,
            'topic': topic,
            'action': action,
            'cver': cver,
        }

    def test_on_elect(self):
        """A valid election request is recorded and acknowledged.  A
        request missing ``clock`` is logged with ``logger.exception``."""
        c = self.Consumer()
        g = Gossip(c)
        g.start(c)
        event = self.Event('id1')
        g.on_elect(event)
        in_heap = g.consensus_requests['id1']
        assert in_heap
        g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
        event.pop('clock')
        with patch('celery.worker.consumer.gossip.logger') as logger:
            g.on_elect(event)
            logger.exception.assert_called()

    def Consumer(self, hostname='foo@x.com', pid=4312):
        """Return a mock consumer with the given hostname and pid.

        Both ``app.connection`` and ``app.connection_for_read`` are fresh
        ``_amqp_connection()`` factories, so callers can construct
        ``Gossip(c)`` directly.
        """
        c = Mock()
        c.app.connection = _amqp_connection()
        c.app.connection_for_read = _amqp_connection()
        c.hostname = hostname
        c.pid = pid
        return c

    def setup_election(self, g, c):
        """Run a complete three-node election for ``'id1'`` on *g*.

        The steps are:

        1. Start the step and check that an ack for an unknown id is
           ignored.
        2. Feed requests from foo, bar and baz (clocks 1, 2, 3).
        3. Feed the three acks.  After the third ack the ``'id1'`` entry is
           gone from ``consensus_replies``.

        Callers check which election handler (if any) ran.
        """
        g.start(c)
        g.clock = self.app.clock
        assert 'idx' not in g.consensus_replies
        assert g.on_elect_ack({'id': 'idx'}) is None
        g.state.alive_workers.return_value = [
            'foo@x.com', 'bar@x.com', 'baz@x.com',
        ]
        g.consensus_replies['id1'] = []
        g.consensus_requests['id1'] = []
        e1 = self.Event('id1', 1, 'foo@x.com')
        e2 = self.Event('id1', 2, 'bar@x.com')
        e3 = self.Event('id1', 3, 'baz@x.com')
        g.on_elect(e1)
        g.on_elect(e2)
        g.on_elect(e3)
        assert len(g.consensus_requests['id1']) == 3
        with patch('celery.worker.consumer.gossip.info'):
            g.on_elect_ack(e1)
            assert len(g.consensus_replies['id1']) == 1
            g.on_elect_ack(e2)
            assert len(g.consensus_replies['id1']) == 2
            g.on_elect_ack(e3)
            with pytest.raises(KeyError):
                g.consensus_replies['id1']

    def test_on_elect_ack_win(self):
        """The winning node runs the handler registered for the topic."""
        c = self.Consumer(hostname='foo@x.com')  # I will win
        g = Gossip(c)
        handler = g.election_handlers['topic'] = Mock()
        self.setup_election(g, c)
        handler.assert_called_with('action')

    def test_on_elect_ack_lose(self):
        """A losing node does not run the topic handler."""
        c = self.Consumer(hostname='bar@x.com')  # I will lose
        g = Gossip(c)
        handler = g.election_handlers['topic'] = Mock()
        self.setup_election(g, c)
        handler.assert_not_called()

    def test_on_elect_ack_win_but_no_action(self):
        """Winning with no handler registered is logged with
        ``logger.exception``."""
        c = self.Consumer(hostname='foo@x.com')  # I will win
        g = Gossip(c)
        g.election_handlers = {}
        with patch('celery.worker.consumer.gossip.logger') as logger:
            self.setup_election(g, c)
            logger.exception.assert_called()

    def test_on_node_join(self):
        """A joining node is debug-logged by hostname."""
        c = self.Consumer()
        g = Gossip(c)
        with patch('celery.worker.consumer.gossip.debug') as debug:
            # The mock consumer doubles as the worker; its hostname is
            # 'foo@x.com'.
            g.on_node_join(c)
            debug.assert_called_with('%s joined the party', 'foo@x.com')

    def test_on_node_leave(self):
        """A leaving node is debug-logged by hostname."""
        c = self.Consumer()
        g = Gossip(c)
        with patch('celery.worker.consumer.gossip.debug') as debug:
            g.on_node_leave(c)
            debug.assert_called_with('%s left', 'foo@x.com')

    def test_on_node_lost(self):
        """A lost node is info-logged as a missed heartbeat."""
        c = self.Consumer()
        g = Gossip(c)
        with patch('celery.worker.consumer.gossip.info') as info:
            g.on_node_lost(c)
            info.assert_called_with('missed heartbeat from %s', 'foo@x.com')

    def test_register_timer(self):
        """``register_timer`` schedules ``periodic`` every ``interval``.
        Registering again cancels the previous timer."""
        c = self.Consumer()
        g = Gossip(c)
        g.register_timer()
        c.timer.call_repeatedly.assert_called_with(g.interval, g.periodic)
        tref = g._tref
        g.register_timer()
        tref.cancel.assert_called_with()

    def test_periodic(self):
        """``periodic`` reports dead workers via ``on_node_lost`` and removes
        them from ``state.workers``."""
        c = self.Consumer()
        g = Gossip(c)
        g.on_node_lost = Mock()
        state = g.state = Mock()
        worker = Mock()
        state.workers = {'foo': worker}
        worker.alive = True
        worker.hostname = 'foo'
        g.periodic()
        worker.alive = False
        g.periodic()
        g.on_node_lost.assert_called_with(worker)
        with pytest.raises(KeyError):
            state.workers['foo']

    def test_on_message__task(self):
        """Smoke test: a ``task.failed`` message does not raise."""
        c = self.Consumer()
        g = Gossip(c)
        assert g.enabled
        message = Mock(name='message')
        message.delivery_info = {'routing_key': 'task.failed'}
        g.on_message(Mock(name='prepare'), message)

    def test_on_message(self):
        """Exercise ``on_message`` across routing keys and senders.

        Covers:

        * the ``State`` construction arguments;
        * dispatch to a registered event handler, and the case with no
          handler;
        * ``worker-offline`` and unknown routing keys;
        * a message from our own hostname, which forwards the clock.
        """
        c = self.Consumer()
        g = Gossip(c)
        assert g.enabled
        prepare = Mock()
        prepare.return_value = 'worker-online', {}
        c.app.events.State.assert_called_with(
            on_node_join=g.on_node_join,
            on_node_leave=g.on_node_leave,
            max_tasks_in_memory=1,
        )
        g.update_state = Mock()
        worker = Mock()
        g.on_node_join = Mock()
        g.on_node_leave = Mock()
        g.update_state.return_value = worker, 1
        message = Mock()
        message.delivery_info = {'routing_key': 'worker-online'}
        message.headers = {'hostname': 'other'}

        # A registered handler receives the message payload.
        handler = g.event_handlers['worker-online'] = Mock()
        g.on_message(prepare, message)
        handler.assert_called_with(message.payload)

        # Same message with no handlers registered.
        g.event_handlers = {}
        g.on_message(prepare, message)

        message.delivery_info = {'routing_key': 'worker-offline'}
        prepare.return_value = 'worker-offline', {}
        g.on_message(prepare, message)

        message.delivery_info = {'routing_key': 'worker-baz'}
        prepare.return_value = 'worker-baz', {}
        g.update_state.return_value = worker, 0
        g.on_message(prepare, message)

        # A message whose sender hostname is our own.
        message.headers = {'hostname': g.hostname}
        g.on_message(prepare, message)
        g.clock.forward.assert_called_with()
|