123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
- from __future__ import absolute_import
- import socket
- from mock import Mock
- from celery.exceptions import InvalidTaskError, SystemTerminate
- from celery.five import Empty
- from celery.worker import state
- from celery.worker.loops import asynloop, synloop, CLOSE, READ, WRITE, ERR
- from celery.tests.utils import AppCase, body_from_sig
- class X(object):
- def __init__(self, heartbeat=None, on_task=None):
- (
- self.obj,
- self.connection,
- self.consumer,
- self.strategies,
- self.blueprint,
- self.hub,
- self.qos,
- self.heartbeat,
- self.handle_unknown_message,
- self.handle_unknown_task,
- self.handle_invalid_task,
- self.clock,
- ) = self.args = [Mock(name='obj'),
- Mock(name='connection'),
- Mock(name='consumer'),
- {},
- Mock(name='blueprint'),
- Mock(name='Hub'),
- Mock(name='qos'),
- heartbeat,
- Mock(name='handle_unknown_message'),
- Mock(name='handle_unknown_task'),
- Mock(name='handle_invalid_task'),
- Mock(name='clock')]
- self.connection.supports_heartbeats = True
- self.consumer.callbacks = []
- self.connection.connection_errors = (socket.error, )
- #hent = self.Hub.__enter__ = Mock(name='Hub.__enter__')
- #self.Hub.__exit__ = Mock(name='Hub.__exit__')
- #self.hub = hent.return_value = Mock(name='hub_context')
- self.hub.on_task = on_task or []
- self.hub.readers = {}
- self.hub.writers = {}
- self.hub.fire_timers.return_value = 1.7
- self.Hub = self.hub
- def timeout_then_error(self, mock):
- def first(*args, **kwargs):
- mock.side_effect = socket.error()
- self.connection.more_to_read = False
- raise socket.timeout()
- mock.side_effect = first
- def close_then_error(self, mock):
- def first(*args, **kwargs):
- self.close()
- self.connection.more_to_read = False
- raise socket.error()
- mock.side_effect = first
- def close(self, *args, **kwargs):
- self.blueprint.state = CLOSE
- def closer(self, mock=None):
- mock = Mock() if mock is None else mock
- mock.side_effect = self.close
- return mock
- def get_task_callback(*args, **kwargs):
- x = X(*args, **kwargs)
- x.blueprint.state = CLOSE
- asynloop(*x.args)
- return x, x.consumer.callbacks[0]
- class test_asynloop(AppCase):
- def setup(self):
- @self.app.task()
- def add(x, y):
- return x + y
- self.add = add
- def test_setup_heartbeat(self):
- x = X(heartbeat=10)
- x.blueprint.state = CLOSE
- asynloop(*x.args)
- x.consumer.consume.assert_called_with()
- x.obj.on_ready.assert_called_with()
- x.hub.timer.apply_interval.assert_called_with(
- 10 * 1000.0 / 2.0, x.connection.heartbeat_check, (2.0, ),
- )
- def task_context(self, sig, **kwargs):
- x, on_task = get_task_callback(**kwargs)
- body = body_from_sig(self.app, sig)
- message = Mock()
- strategy = x.strategies[sig.task] = Mock()
- return x, on_task, body, message, strategy
- def test_on_task_received(self):
- _, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
- on_task(body, msg)
- strategy.assert_called_with(msg, body, msg.ack_log_error)
- def test_on_task_received_executes_hub_on_task(self):
- cbs = [Mock(), Mock(), Mock()]
- _, on_task, body, msg, _ = self.task_context(
- self.add.s(2, 2), on_task=cbs,
- )
- on_task(body, msg)
- [cb.assert_called_with() for cb in cbs]
- def test_on_task_message_missing_name(self):
- x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
- body.pop('task')
- on_task(body, msg)
- x.handle_unknown_message.assert_called_with(body, msg)
- def test_on_task_not_registered(self):
- x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
- exc = strategy.side_effect = KeyError(self.add.name)
- on_task(body, msg)
- x.handle_unknown_task.assert_called_with(body, msg, exc)
- def test_on_task_InvalidTaskError(self):
- x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
- exc = strategy.side_effect = InvalidTaskError()
- on_task(body, msg)
- x.handle_invalid_task.assert_called_with(body, msg, exc)
- def test_should_terminate(self):
- x = X()
- # XXX why aren't the errors propagated?!?
- state.should_terminate = True
- try:
- with self.assertRaises(SystemTerminate):
- asynloop(*x.args)
- finally:
- state.should_terminate = False
- def test_should_terminate_hub_close_raises(self):
- x = X()
- # XXX why aren't the errors propagated?!?
- state.should_terminate = True
- x.hub.close.side_effect = MemoryError()
- try:
- with self.assertRaises(SystemTerminate):
- asynloop(*x.args)
- finally:
- state.should_terminate = False
- def test_should_stop(self):
- x = X()
- state.should_stop = True
- try:
- with self.assertRaises(SystemExit):
- asynloop(*x.args)
- finally:
- state.should_stop = False
- def test_updates_qos(self):
- x = X()
- x.qos.prev = 3
- x.qos.value = 3
- asynloop(*x.args, sleep=x.closer())
- self.assertFalse(x.qos.update.called)
- x = X()
- x.qos.prev = 1
- x.qos.value = 6
- asynloop(*x.args, sleep=x.closer())
- x.qos.update.assert_called_with()
- x.hub.fire_timers.assert_called_with(propagate=(socket.error, ))
- x.connection.transport.on_poll_start.assert_called_with()
- def test_poll_empty(self):
- x = X()
- x.hub.readers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.fire_timers.return_value = 33.37
- x.hub.poller.poll.return_value = []
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- x.hub.poller.poll.assert_called_with(33.37)
- x.connection.transport.on_poll_empty.assert_called_with()
- def test_poll_readable(self):
- x = X()
- x.hub.readers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, READ)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- x.hub.readers[6].assert_called_with(6, READ)
- self.assertTrue(x.hub.poller.poll.called)
- def test_poll_readable_raises_Empty(self):
- x = X()
- x.hub.readers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, READ)]
- x.hub.readers[6].side_effect = Empty()
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- x.hub.readers[6].assert_called_with(6, READ)
- self.assertTrue(x.hub.poller.poll.called)
- def test_poll_writable(self):
- x = X()
- x.hub.writers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, WRITE)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- x.hub.writers[6].assert_called_with(6, WRITE)
- self.assertTrue(x.hub.poller.poll.called)
- def test_poll_writable_none_registered(self):
- x = X()
- x.hub.writers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(7, WRITE)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- self.assertTrue(x.hub.poller.poll.called)
- def test_poll_unknown_event(self):
- x = X()
- x.hub.writers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, 0)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- self.assertTrue(x.hub.poller.poll.called)
- def test_poll_keep_draining_disabled(self):
- x = X()
- x.hub.writers = {6: Mock()}
- poll = x.hub.poller.poll
- def se(*args, **kwargs):
- poll.side_effect = socket.error()
- poll.side_effect = se
- x.connection.transport.nb_keep_draining = False
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, 0)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- self.assertTrue(x.hub.poller.poll.called)
- self.assertFalse(x.connection.drain_nowait.called)
- def test_poll_err_writable(self):
- x = X()
- x.hub.writers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, ERR)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- x.hub.writers[6].assert_called_with(6, ERR)
- self.assertTrue(x.hub.poller.poll.called)
- def test_poll_write_generator(self):
- x = X()
- def Gen():
- yield 1
- yield 2
- gen = Gen()
- x.hub.writers = {6: gen}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, WRITE)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- self.assertTrue(gen.gi_frame.f_lasti != -1)
- self.assertFalse(x.hub.remove.called)
- def test_poll_write_generator_stopped(self):
- x = X()
- def Gen():
- raise StopIteration()
- yield
- gen = Gen()
- x.hub.writers = {6: gen}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, WRITE)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- self.assertIsNone(gen.gi_frame)
- x.hub.remove.assert_called_with(6)
- def test_poll_write_generator_raises(self):
- x = X()
- def Gen():
- raise ValueError('foo')
- yield
- gen = Gen()
- x.hub.writers = {6: gen}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, WRITE)]
- with self.assertRaises(ValueError):
- asynloop(*x.args)
- self.assertIsNone(gen.gi_frame)
- x.hub.remove.assert_called_with(6)
- def test_poll_err_readable(self):
- x = X()
- x.hub.readers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.return_value = [(6, ERR)]
- with self.assertRaises(socket.error):
- asynloop(*x.args)
- x.hub.readers[6].assert_called_with(6, ERR)
- self.assertTrue(x.hub.poller.poll.called)
- def test_poll_raises_ValueError(self):
- x = X()
- x.hub.readers = {6: Mock()}
- x.close_then_error(x.connection.drain_nowait)
- x.hub.poller.poll.side_effect = ValueError()
- asynloop(*x.args)
- self.assertTrue(x.hub.poller.poll.called)
- class test_synloop(AppCase):
- def test_timeout_ignored(self):
- x = X()
- x.timeout_then_error(x.connection.drain_events)
- with self.assertRaises(socket.error):
- synloop(*x.args)
- self.assertEqual(x.connection.drain_events.call_count, 2)
- def test_updates_qos_when_changed(self):
- x = X()
- x.qos.prev = 2
- x.qos.value = 2
- x.timeout_then_error(x.connection.drain_events)
- with self.assertRaises(socket.error):
- synloop(*x.args)
- self.assertFalse(x.qos.update.called)
- x.qos.value = 4
- x.timeout_then_error(x.connection.drain_events)
- with self.assertRaises(socket.error):
- synloop(*x.args)
- x.qos.update.assert_called_with()
- def test_ignores_socket_errors_when_closed(self):
- x = X()
- x.close_then_error(x.connection.drain_events)
- self.assertIsNone(synloop(*x.args))
|