123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- import errno
- import socket
- import pytest
- import queue
- from case import Mock
- from kombu.async import Hub, READ, WRITE, ERR
- from celery.bootsteps import CLOSE, RUN
- from celery.exceptions import (
- InvalidTaskError, WorkerLostError, WorkerShutdown, WorkerTerminate,
- )
- from celery.platforms import EX_FAILURE
- from celery.worker import state
- from celery.worker.consumer import Consumer
- from celery.worker.loops import _quick_drain, asynloop, synloop
class PromiseEqual:
    """Equality matcher for promise-like objects in mock call assertions.

    An instance compares equal to any object whose ``fun``, ``args`` and
    ``kwargs`` attributes equal the values given here.  Tests can then
    assert on the promises passed to a strategy without holding the exact
    promise instances that were created.
    """

    def __init__(self, fun, *args, **kwargs):
        self.fun = fun
        self.args = args
        self.kwargs = kwargs

    def __eq__(self, other):
        # Objects that are not promise-like (no fun/args/kwargs) must not
        # crash the comparison.  Returning NotImplemented lets Python fall
        # back to its default (identity) result, so a mismatch inside a
        # mock assertion reports cleanly instead of raising AttributeError.
        try:
            other_fun = other.fun
            other_args = other.args
            other_kwargs = other.kwargs
        except AttributeError:
            return NotImplemented
        return (other_fun == self.fun and
                other_args == self.args and
                other_kwargs == self.kwargs)

    def __repr__(self):
        return '<promise: {0.fun!r} {0.args!r} {0.kwargs!r}>'.format(self)
class X:
    """Mock fixture for the worker loop functions.

    ``self.args`` is the positional argument list handed to
    ``asynloop``/``synloop``: obj, connection, consumer, blueprint, hub,
    qos, heartbeat, clock.  Each entry is also exposed as an attribute of
    the same name.  ``hub`` is a real ``Hub`` whose timer, poller,
    ``fire_timers`` and ``close`` are replaced with mocks.  ``heartbeat``
    is whatever value the caller passed.  Every other entry is a ``Mock``.
    """

    def __init__(self, app, heartbeat=None, on_task_message=None,
                 transport_driver_type=None):
        loop_args = [
            Mock(name='obj'),
            Mock(name='connection'),
            Mock(name='consumer'),
            Mock(name='blueprint'),
            Hub(),
            Mock(name='qos'),
            heartbeat,
            Mock(name='clock'),
        ]
        (self.obj, self.connection, self.consumer, self.blueprint,
         self.hub, self.qos, self.heartbeat, self.clock) = loop_args
        self.args = loop_args

        conn = self.connection
        conn.supports_heartbeats = True
        # Looked up lazily so the mock reports the fixture's current value.
        conn.get_heartbeat_interval.side_effect = lambda: self.heartbeat
        conn.connection_errors = (socket.error,)
        if transport_driver_type:
            conn.transport.driver_type = transport_driver_type

        self.consumer.callbacks = []
        self.obj.strategies = {}

        hub = self.hub
        hub.readers = {}
        hub.timer = Mock(name='hub.timer')
        hub.timer._queue = [Mock()]
        hub.fire_timers = Mock(name='hub.fire_timers')
        hub.fire_timers.return_value = 1.7
        hub.poller = Mock(name='hub.poller')
        # asynloop calls hub.close, so it is mocked out as well.
        hub.close = Mock(name='hub.close()')
        self.Hub = hub

        self.blueprint.state = RUN

        # A real Consumer supplies create_task_handler for the loops.
        real_consumer = self._consumer = Consumer(
            Mock(), timer=Mock(), controller=Mock(), app=app)
        real_consumer.on_task_message = on_task_message or []
        self.obj.create_task_handler = real_consumer.create_task_handler

        # One shared mock per callback, visible on the fixture, on obj
        # and on the real consumer so tests can assert on any of them.
        for handler_name in ('on_unknown_message',
                             'on_unknown_task',
                             'on_invalid_task'):
            handler = Mock(name=handler_name)
            setattr(self, handler_name, handler)
            setattr(self.obj, handler_name, handler)
            setattr(real_consumer, handler_name, handler)

        real_consumer.strategies = self.obj.strategies

    def timeout_then_error(self, mock):
        """Make *mock* raise ``socket.timeout`` once, then ``socket.error``."""
        def _timeout_once(*args, **kwargs):
            mock.side_effect = socket.error()
            raise socket.timeout()
        mock.side_effect = _timeout_once

    def close_then_error(self, mock=None, mod=0, exc=None):
        """Arm *mock* (a fresh ``Mock`` by default) to close and raise.

        Once ``mock.call_count`` exceeds *mod*, or on every call when *mod*
        is 0, a call switches the blueprint to CLOSE and raises *exc*
        (``socket.error()`` when *exc* is None).  Earlier calls return
        None.  Returns the armed mock.
        """
        if mock is None:
            mock = Mock()

        def _close_and_raise(*args, **kwargs):
            if mod and mock.call_count <= mod:
                return None
            self.close()
            raise (socket.error() if exc is None else exc)

        mock.side_effect = _close_and_raise
        return mock

    def close(self, *args, **kwargs):
        """Switch the blueprint to CLOSE; any arguments are ignored."""
        self.blueprint.state = CLOSE

    def closer(self, mock=None, mod=0):
        """Arm *mock* (a fresh ``Mock`` by default) to close without raising.

        Once ``mock.call_count`` reaches *mod*, or on every call when *mod*
        is 0, a call switches the blueprint to CLOSE.  Returns the armed
        mock.
        """
        if mock is None:
            mock = Mock()

        def _close_when_due(*args, **kwargs):
            if mod and mock.call_count < mod:
                return
            self.close()

        mock.side_effect = _close_when_due
        return mock
def get_task_callback(*args, **kwargs):
    """Build an ``X`` fixture, run ``asynloop`` once, return its callback.

    All arguments are forwarded to ``X``.  Returns ``(fixture, on_message)``,
    where ``on_message`` is read from ``fixture.consumer`` after the loop
    has run.
    """
    fixture = X(*args, **kwargs)
    # Blueprint starts in CLOSE, presumably so asynloop performs its setup
    # (installing consumer.on_message) without looping -- confirm against
    # asynloop.
    fixture.blueprint.state = CLOSE
    asynloop(*fixture.args)
    on_message = fixture.consumer.on_message
    return fixture, on_message
class test_asynloop:
    """Tests for ``asynloop``, driven through the mocks built by ``X``.

    ``self.app`` and ``self.task_message_from_sig`` are not defined in this
    class; presumably a project conftest fixture injects them -- confirm.
    """

    def setup(self):
        # Per-test ``add`` task registered only on this test's app.
        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        self.add = add

    def test_drain_after_consume(self):
        # With an 'amqp' transport driver, _quick_drain is expected among
        # the hub's ready callbacks.
        x, _ = get_task_callback(self.app, transport_driver_type='amqp')
        assert _quick_drain in [p.fun for p in x.hub._ready]

    def test_pool_did_not_start_at_startup(self):
        x = X(self.app)
        x.obj.restart_count = 0
        x.obj.pool.did_start_ok.return_value = False
        with pytest.raises(WorkerLostError):
            asynloop(*x.args)

    def test_setup_heartbeat(self):
        x = X(self.app, heartbeat=10)
        x.hub.timer.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
        x.blueprint.state = CLOSE
        asynloop(*x.args)
        x.consumer.consume.assert_called_with()
        x.obj.on_ready.assert_called_with()
        x.hub.timer.call_repeatedly.assert_called_with(
            10 / 2.0, x.connection.heartbeat_check, (2.0,),
        )

    def task_context(self, sig, **kwargs):
        """Return ``(x, on_task, message, strategy)`` for signature *sig*.

        ``strategy`` is a fresh mock registered as the strategy for
        ``sig.task``; extra keyword arguments are forwarded to ``X``.
        """
        x, on_task = get_task_callback(self.app, **kwargs)
        message = self.task_message_from_sig(self.app, sig)
        strategy = x.obj.strategies[sig.task] = Mock(name='strategy')
        return x, on_task, message, strategy

    def test_on_task_received(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        on_task(msg)
        strategy.assert_called_with(
            msg, None,
            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
            PromiseEqual(x._consumer.call_soon, msg.reject_log_error), [],
        )

    def test_on_task_received_executes_on_task_message(self):
        cbs = [Mock(), Mock(), Mock()]
        x, on_task, msg, strategy = self.task_context(
            self.add.s(2, 2), on_task_message=cbs,
        )
        on_task(msg)
        strategy.assert_called_with(
            msg, None,
            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
            PromiseEqual(x._consumer.call_soon, msg.reject_log_error),
            cbs,
        )

    def test_on_task_message_missing_name(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        msg.headers.pop('task')
        on_task(msg)
        x.on_unknown_message.assert_called_with(msg.decode(), msg)

    def test_on_task_pool_raises(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        strategy.side_effect = ValueError()
        with pytest.raises(ValueError):
            on_task(msg)

    def test_on_task_InvalidTaskError(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        exc = strategy.side_effect = InvalidTaskError()
        on_task(msg)
        x.on_invalid_task.assert_called_with(None, msg, exc)

    def test_should_terminate(self):
        x = X(self.app)
        # XXX why aren't the errors propagated?!?
        state.should_terminate = True
        try:
            with pytest.raises(WorkerTerminate):
                asynloop(*x.args)
        finally:
            # Global worker state: always restore it for later tests.
            state.should_terminate = None

    def test_should_terminate_hub_close_raises(self):
        x = X(self.app)
        # XXX why aren't the errors propagated?!?
        state.should_terminate = EX_FAILURE
        x.hub.close.side_effect = MemoryError()
        try:
            with pytest.raises(WorkerTerminate):
                asynloop(*x.args)
        finally:
            state.should_terminate = None

    def test_should_stop(self):
        x = X(self.app)
        state.should_stop = 303
        try:
            with pytest.raises(WorkerShutdown):
                asynloop(*x.args)
        finally:
            state.should_stop = None

    def test_updates_qos(self):
        # Unchanged QoS (prev == value): no update expected.
        x = X(self.app)
        x.qos.prev = 3
        x.qos.value = 3
        x.hub.on_tick.add(x.closer(mod=2))
        x.hub.timer._queue = [1]
        asynloop(*x.args)
        x.qos.update.assert_not_called()

        # Changed QoS (prev != value): update expected.
        x = X(self.app)
        x.qos.prev = 1
        x.qos.value = 6
        x.hub.on_tick.add(x.closer(mod=2))
        asynloop(*x.args)
        x.qos.update.assert_called_with()
        x.hub.fire_timers.assert_called_with(propagate=(socket.error,))

    def test_poll_empty(self):
        x = X(self.app)
        x.hub.readers = {6: Mock()}
        x.hub.timer._queue = [1]
        x.close_then_error(x.hub.poller.poll)
        x.hub.fire_timers.return_value = 33.37
        poller = x.hub.poller
        poller.poll.return_value = []
        with pytest.raises(socket.error):
            asynloop(*x.args)
        # The value returned by fire_timers is used as the poll timeout.
        poller.poll.assert_called_with(33.37)

    def test_poll_readable(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
        poller = x.hub.poller
        poller.poll.return_value = [(6, READ)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6)
        poller.poll.assert_called()

    def test_poll_readable_raises_Empty(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, READ)]
        reader.side_effect = queue.Empty()
        with pytest.raises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6)
        poller.poll.assert_called()

    def test_poll_writable(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, WRITE)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        writer.assert_called_with(6)
        poller.poll.assert_called()

    def test_poll_writable_none_registered(self):
        # Poll reports fd 7, but only fd 6 has a writer.
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(7, WRITE)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        poller.poll.assert_called()

    def test_poll_unknown_event(self):
        # Event mask 0 matches none of READ/WRITE/ERR.
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, 0)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        poller.poll.assert_called()

    def test_poll_keep_draining_disabled(self):
        x = X(self.app)
        x.hub.writers = {6: Mock()}
        poll = x.hub.poller.poll

        # First poll call returns None (the side effect's return value)
        # and arms every later call to raise socket.error.
        def se(*args, **kwargs):
            poll.side_effect = socket.error()
        poll.side_effect = se

        poller = x.hub.poller
        poll.return_value = [(6, 0)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        poller.poll.assert_called()

    def test_poll_err_writable(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6, 48)
        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, ERR)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        writer.assert_called_with(6, 48)
        poller.poll.assert_called()

    def test_poll_write_generator(self):
        import inspect

        x = X(self.app)
        x.hub.remove = Mock(name='hub.remove()')

        def Gen():
            yield 1
            yield 2
        gen = Gen()

        x.hub.add_writer(6, gen)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        # The writer generator must have been advanced at least once.
        # ``gi_frame.f_lasti != -1`` relies on a CPython frame detail that
        # newer releases do not guarantee; getgeneratorstate is the
        # documented way to ask whether a generator has started.
        assert inspect.getgeneratorstate(gen) != inspect.GEN_CREATED
        x.hub.remove.assert_not_called()

    def test_poll_write_generator_stopped(self):
        x = X(self.app)

        def Gen():
            # Finish immediately without yielding.  A bare ``return`` ends
            # the generator with StopIteration.  Raising StopIteration
            # explicitly in the body is turned into RuntimeError by PEP 479.
            return
            yield  # unreachable; only makes Gen a generator function
        gen = Gen()

        x.hub.add_writer(6, gen)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        x.hub.remove = Mock(name='hub.remove()')
        with pytest.raises(socket.error):
            asynloop(*x.args)
        assert gen.gi_frame is None

    def test_poll_write_generator_raises(self):
        x = X(self.app)

        def Gen():
            raise ValueError('foo')
            yield
        gen = Gen()

        x.hub.add_writer(6, gen)
        x.hub.remove = Mock(name='hub.remove()')
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        with pytest.raises(ValueError):
            asynloop(*x.args)
        assert gen.gi_frame is None
        x.hub.remove.assert_called_with(6)

    def test_poll_err_readable(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6, 24)
        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, ERR)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6, 24)
        poller.poll.assert_called()

    def test_poll_raises_ValueError(self):
        # The test expects asynloop to return without propagating
        # ValueError from poll (no pytest.raises here).
        x = X(self.app)
        x.hub.readers = {6: Mock()}
        poller = x.hub.poller
        x.close_then_error(poller.poll, exc=ValueError)
        asynloop(*x.args)
        poller.poll.assert_called()
class test_synloop:
    """Tests for ``synloop``, driven through ``connection.drain_events``."""

    def test_timeout_ignored(self):
        fixture = X(self.app)
        drain = fixture.connection.drain_events
        # First drain times out and the loop is expected to keep going;
        # the second drain raises socket.error, which escapes.
        fixture.timeout_then_error(drain)
        with pytest.raises(socket.error):
            synloop(*fixture.args)
        assert drain.call_count == 2

    def test_updates_qos_when_changed(self):
        fixture = X(self.app)
        qos = fixture.qos
        drain = fixture.connection.drain_events

        qos.prev = qos.value = 2
        fixture.timeout_then_error(drain)
        with pytest.raises(socket.error):
            synloop(*fixture.args)
        qos.update.assert_not_called()

        qos.value = 4
        fixture.timeout_then_error(drain)
        with pytest.raises(socket.error):
            synloop(*fixture.args)
        qos.update.assert_called_with()

    def test_ignores_socket_errors_when_closed(self):
        fixture = X(self.app)
        # drain_events closes the blueprint and then raises socket.error.
        fixture.close_then_error(fixture.connection.drain_events)
        result = synloop(*fixture.args)
        assert result is None
class test_quick_drain:
    """Tests for ``_quick_drain`` against a mocked connection."""

    def setup(self):
        self.connection = Mock(name='connection')

    def test_drain(self):
        _quick_drain(self.connection, timeout=33.3)
        self.connection.drain_events.assert_called_with(timeout=33.3)

    def test_drain_error(self):
        # An error whose errno is not EAGAIN is expected to propagate.
        exc = KeyError()
        exc.errno = 313
        self.connection.drain_events.side_effect = exc
        with pytest.raises(KeyError):
            _quick_drain(self.connection, timeout=33.3)
        self.connection.drain_events.assert_called_with(timeout=33.3)

    def test_drain_error_EAGAIN(self):
        # An error carrying errno EAGAIN is expected to be swallowed.
        exc = KeyError()
        exc.errno = errno.EAGAIN
        self.connection.drain_events.side_effect = exc
        _quick_drain(self.connection, timeout=33.3)
        # Without this the test would also pass if _quick_drain never
        # attempted to drain at all.
        self.connection.drain_events.assert_called_with(timeout=33.3)
|