123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- from __future__ import absolute_import
- import socket
- from mock import Mock
- from kombu.async import Hub, READ, WRITE, ERR
- from celery.bootsteps import CLOSE, RUN
- from celery.exceptions import InvalidTaskError, SystemTerminate
- from celery.five import Empty
- from celery.worker import state
- from celery.worker.consumer import Consumer
- from celery.worker.loops import asynloop, synloop
- from celery.tests.case import AppCase, body_from_sig
class X(object):
    """Bundle of mocks standing in for the arguments of the worker loops.

    ``self.args`` is the positional argument list handed to ``asynloop`` /
    ``synloop`` by the tests; the same objects are also exposed as
    individual attributes (``obj``, ``connection``, ``consumer``,
    ``blueprint``, ``hub``, ``qos``, ``heartbeat``, ``clock``) so a test
    can configure them before the loop runs and inspect them afterwards.
    """

    def __init__(self, app, heartbeat=None, on_task_message=None):
        hub = Hub()
        loop_args = [
            Mock(name='obj'),
            Mock(name='connection'),
            Mock(name='consumer'),
            Mock(name='blueprint'),
            hub,
            Mock(name='qos'),
            heartbeat,
            Mock(name='clock'),
        ]
        (self.obj, self.connection, self.consumer, self.blueprint,
         self.hub, self.qos, self.heartbeat, self.clock) = loop_args
        self.args = loop_args

        self.connection.supports_heartbeats = True
        self.connection.connection_errors = (socket.error, )
        self.consumer.callbacks = []
        self.obj.strategies = {}

        # The hub is a real Hub instance, but everything the loop touches
        # on it is replaced so tests control polling and timers.
        hub.readers = {}
        hub.writers = {}
        hub.consolidate = set()
        hub.timer = Mock(name='hub.timer')
        hub.timer._queue = [Mock()]
        hub.fire_timers = Mock(name='hub.fire_timers', return_value=1.7)
        hub.poller = Mock(name='hub.poller')
        hub.close = Mock(name='hub.close()')  # asynloop calls hub.close
        self.Hub = hub
        self.blueprint.state = RUN

        # A real Consumer is needed for create_task_handler; its error
        # handlers are swapped for mocks shared with ``self`` and ``self.obj``
        # so tests can assert on them.
        real_consumer = Consumer(Mock(), timer=Mock(), app=app)
        real_consumer.on_task_message = on_task_message or []
        self.obj.create_task_handler = real_consumer.create_task_handler
        for handler_name in ('on_unknown_message',
                             'on_unknown_task',
                             'on_invalid_task'):
            handler = Mock(name=handler_name)
            setattr(self, handler_name, handler)
            setattr(self.obj, handler_name, handler)
            setattr(real_consumer, handler_name, handler)
        real_consumer.strategies = self.obj.strategies

    def timeout_then_error(self, mock):
        """Make ``mock`` raise ``socket.timeout`` on its first call.

        That first call also sets ``connection.more_to_read`` to False and
        re-arms ``mock`` so every later call raises ``socket.error``.
        """

        def on_first_call(*args, **kwargs):
            mock.side_effect = socket.error()
            self.connection.more_to_read = False
            raise socket.timeout()
        mock.side_effect = on_first_call

    def close_then_error(self, mock=None, mod=0, exc=None):
        """Make ``mock`` close the blueprint and then raise.

        With ``mod`` falsy this happens on the first call; otherwise calls
        are ignored while ``mock.call_count <= mod``. The raised exception is
        ``exc`` if given, else a fresh ``socket.error``. A new ``Mock`` is
        created when ``mock`` is None; the configured mock is returned.
        """
        if mock is None:
            mock = Mock()

        def close_and_raise(*args, **kwargs):
            if mod and mock.call_count <= mod:
                return
            self.close()
            self.connection.more_to_read = False
            raise socket.error() if exc is None else exc
        mock.side_effect = close_and_raise
        return mock

    def close(self, *args, **kwargs):
        """Set the blueprint state to ``CLOSE`` to tell the loop to stop."""
        self.blueprint.state = CLOSE

    def closer(self, mock=None, mod=0):
        """Make ``mock`` call :meth:`close` without raising.

        With ``mod`` falsy every call closes; otherwise calls are ignored
        while ``mock.call_count < mod``. A new ``Mock`` is created when
        ``mock`` is None; the configured mock is returned.
        """
        if mock is None:
            mock = Mock()

        def maybe_close(*args, **kwargs):
            if mod and mock.call_count < mod:
                return
            self.close()
        mock.side_effect = maybe_close
        return mock
def get_task_callback(*args, **kwargs):
    """Build an :class:`X` fixture and run ``asynloop`` through its setup.

    The blueprint is put in ``CLOSE`` first, so ``asynloop`` returns right
    after setup. Arguments are forwarded to :class:`X`.

    Returns ``(fixture, callback)``, where ``callback`` is the first entry
    ``asynloop`` left in ``fixture.consumer.callbacks``. The tests use it as
    the task-message handler.
    """
    fixture = X(*args, **kwargs)
    fixture.blueprint.state = CLOSE
    asynloop(*fixture.args)
    registered = fixture.consumer.callbacks
    return fixture, registered[0]
class test_asynloop(AppCase):
    """Tests for ``asynloop`` driven through the :class:`X` mock fixture."""

    def setup(self):

        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        self.add = add

    def test_setup_heartbeat(self):
        x = X(self.app, heartbeat=10)
        x.hub.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
        x.blueprint.state = CLOSE
        asynloop(*x.args)
        x.consumer.consume.assert_called_with()
        x.obj.on_ready.assert_called_with()
        # The heartbeat check is scheduled on the hub at half the
        # heartbeat interval.
        x.hub.call_repeatedly.assert_called_with(
            10 / 2.0, x.connection.heartbeat_check, 2.0,
        )

    def task_context(self, sig, **kwargs):
        """Return ``(x, on_task, body, message, strategy)`` for ``sig``.

        A mock strategy is registered under ``sig.task`` so tests can
        observe how the task callback dispatches the message.
        """
        x, on_task = get_task_callback(self.app, **kwargs)
        body = body_from_sig(self.app, sig)
        message = Mock()
        strategy = x.obj.strategies[sig.task] = Mock()
        return x, on_task, body, message, strategy

    def test_on_task_received(self):
        _, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
        on_task(body, msg)
        strategy.assert_called_with(
            msg, body, msg.ack_log_error, msg.reject_log_error, [],
        )

    def test_on_task_received_executes_on_task_message(self):
        cbs = [Mock(), Mock(), Mock()]
        _, on_task, body, msg, strategy = self.task_context(
            self.add.s(2, 2), on_task_message=cbs,
        )
        on_task(body, msg)
        strategy.assert_called_with(
            msg, body, msg.ack_log_error, msg.reject_log_error, cbs,
        )

    def test_on_task_message_missing_name(self):
        x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
        body.pop('task')
        on_task(body, msg)
        x.on_unknown_message.assert_called_with(body, msg)

    def test_on_task_not_registered(self):
        x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
        exc = strategy.side_effect = KeyError(self.add.name)
        on_task(body, msg)
        x.on_unknown_task.assert_called_with(body, msg, exc)

    def test_on_task_InvalidTaskError(self):
        x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
        exc = strategy.side_effect = InvalidTaskError()
        on_task(body, msg)
        x.on_invalid_task.assert_called_with(body, msg, exc)

    def test_should_terminate(self):
        x = X(self.app)
        # XXX why aren't the errors propagated?!?
        state.should_terminate = True
        try:
            with self.assertRaises(SystemTerminate):
                asynloop(*x.args)
        finally:
            state.should_terminate = False

    def test_should_terminate_hub_close_raises(self):
        x = X(self.app)
        # XXX why aren't the errors propagated?!?
        state.should_terminate = True
        # An error from hub.close() must not mask SystemTerminate.
        x.hub.close.side_effect = MemoryError()
        try:
            with self.assertRaises(SystemTerminate):
                asynloop(*x.args)
        finally:
            state.should_terminate = False

    def test_should_stop(self):
        x = X(self.app)
        state.should_stop = True
        try:
            with self.assertRaises(SystemExit):
                asynloop(*x.args)
        finally:
            state.should_stop = False

    def test_updates_qos(self):
        # qos.update() is only called when qos.value differs from qos.prev.
        x = X(self.app)
        x.qos.prev = 3
        x.qos.value = 3
        x.hub.on_tick.add(x.closer(mod=2))
        x.hub.timer._queue = [1]
        asynloop(*x.args)
        self.assertFalse(x.qos.update.called)

        x = X(self.app)
        x.qos.prev = 1
        x.qos.value = 6
        x.hub.on_tick.add(x.closer(mod=2))
        asynloop(*x.args)
        x.qos.update.assert_called_with()
        # Presumably forwarded from connection.connection_errors, which the
        # fixture sets to (socket.error, ).
        x.hub.fire_timers.assert_called_with(propagate=(socket.error, ))

    def test_poll_empty(self):
        # The poll timeout is whatever fire_timers() returned.
        x = X(self.app)
        x.hub.readers = {6: Mock()}
        x.hub.timer._queue = [1]
        x.close_then_error(x.hub.poller.poll)
        x.hub.fire_timers.return_value = 33.37
        x.hub.poller.poll.return_value = []
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        x.hub.poller.poll.assert_called_with(33.37)

    def test_poll_readable(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
        x.hub.poller.poll.return_value = [(6, READ)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_readable_raises_Empty(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, READ)]
        reader.side_effect = Empty()
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_writable(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        writer.assert_called_with(6)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_writable_none_registered(self):
        # The poller reports fd 7, which has no registered writer.
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(7, WRITE)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_unknown_event(self):
        # Event mask 0 matches none of READ/WRITE/ERR.
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, 0)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_keep_draining_disabled(self):
        x = X(self.app)
        x.hub.writers = {6: Mock()}
        poll = x.hub.poller.poll

        # The first poll returns normally and re-arms poll to raise.
        def se(*args, **kwargs):
            poll.side_effect = socket.error()
        poll.side_effect = se

        x.hub.poller.poll.return_value = [(6, 0)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_err_writable(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6, 48)
        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
        x.hub.poller.poll.return_value = [(6, ERR)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        writer.assert_called_with(6, 48)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_write_generator(self):
        # A generator writer is advanced but not removed while it has
        # values left.
        x = X(self.app)
        x.hub.remove = Mock(name='hub.remove()')

        def Gen():
            yield 1
            yield 2
        gen = Gen()

        x.hub.add_writer(6, gen)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        self.assertNotEqual(gen.gi_frame.f_lasti, -1)
        self.assertFalse(x.hub.remove.called)

    def test_poll_write_generator_stopped(self):
        # An exhausted generator writer is removed from the hub.
        x = X(self.app)

        def Gen():
            # Use ``return`` rather than ``raise StopIteration()``: under
            # PEP 479 (Python 3.7+) the latter turns into RuntimeError.
            return
            yield
        gen = Gen()

        x.hub.add_writer(6, gen)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        x.hub.remove = Mock(name='hub.remove()')
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        self.assertIsNone(gen.gi_frame)
        x.hub.remove.assert_called_with(6)

    def test_poll_write_generator_raises(self):
        # A generator writer that raises is removed and its error propagates.
        x = X(self.app)

        def Gen():
            raise ValueError('foo')
            yield
        gen = Gen()

        x.hub.add_writer(6, gen)
        x.hub.remove = Mock(name='hub.remove()')
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        with self.assertRaises(ValueError):
            asynloop(*x.args)
        self.assertIsNone(gen.gi_frame)
        x.hub.remove.assert_called_with(6)

    def test_poll_err_readable(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6, 24)
        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
        x.hub.poller.poll.return_value = [(6, ERR)]
        with self.assertRaises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6, 24)
        self.assertTrue(x.hub.poller.poll.called)

    def test_poll_raises_ValueError(self):
        # A ValueError from poll() ends the loop without propagating.
        x = X(self.app)
        x.hub.readers = {6: Mock()}
        x.close_then_error(x.hub.poller.poll, exc=ValueError)
        asynloop(*x.args)
        self.assertTrue(x.hub.poller.poll.called)
class test_synloop(AppCase):
    """Tests for ``synloop`` driven through the :class:`X` mock fixture."""

    def test_timeout_ignored(self):
        # The first drain_events() raises socket.timeout and the loop calls
        # it again; the second call's socket.error propagates.
        fixture = X(self.app)
        drain = fixture.connection.drain_events
        fixture.timeout_then_error(drain)
        with self.assertRaises(socket.error):
            synloop(*fixture.args)
        self.assertEqual(2, drain.call_count)

    def test_updates_qos_when_changed(self):
        fixture = X(self.app)
        qos = fixture.qos
        drain = fixture.connection.drain_events
        qos.prev = qos.value = 2

        # qos.value == qos.prev: no update expected.
        fixture.timeout_then_error(drain)
        with self.assertRaises(socket.error):
            synloop(*fixture.args)
        self.assertFalse(qos.update.called)

        # qos.value now differs from qos.prev: update expected.
        qos.value = 4
        fixture.timeout_then_error(drain)
        with self.assertRaises(socket.error):
            synloop(*fixture.args)
        qos.update.assert_called_with()

    def test_ignores_socket_errors_when_closed(self):
        # drain_events() closes the blueprint before raising socket.error,
        # so synloop returns None instead of propagating the error.
        fixture = X(self.app)
        fixture.close_then_error(fixture.connection.drain_events)
        outcome = synloop(*fixture.args)
        self.assertIsNone(outcome)
|