| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472 | 
							- from __future__ import absolute_import, unicode_literals
 
- import errno
 
- import socket
 
- import pytest
 
- from case import Mock
 
- from celery.bootsteps import CLOSE, RUN
 
- from celery.exceptions import (InvalidTaskError, WorkerLostError,
 
-                                WorkerShutdown, WorkerTerminate)
 
- from celery.five import Empty, python_2_unicode_compatible
 
- from celery.platforms import EX_FAILURE
 
- from celery.worker import state
 
- from celery.worker.consumer import Consumer
 
- from celery.worker.loops import _quick_drain, asynloop, synloop
 
- from kombu.async import ERR, READ, WRITE, Hub
 
@python_2_unicode_compatible
class PromiseEqual(object):
    """Equality matcher for promise-like callbacks used in mock assertions.

    An instance compares equal to any object whose ``fun``, ``args`` and
    ``kwargs`` attributes equal the ones captured at construction. Tests can
    then assert on the promise passed to a strategy without holding the
    exact instance that was created.
    """

    def __init__(self, fun, *args, **kwargs):
        self.fun = fun
        self.args = args
        self.kwargs = kwargs

    def __eq__(self, other):
        # Fetch the attributes up front. An object that is not promise-shaped
        # then yields NotImplemented (so Python can try the reflected
        # comparison) instead of an AttributeError escaping the assertion.
        try:
            fun, args, kwargs = other.fun, other.args, other.kwargs
        except AttributeError:
            return NotImplemented
        return (fun == self.fun and
                args == self.args and
                kwargs == self.kwargs)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``. Without this,
        # ``!=`` would fall back to identity and disagree with ``==``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        return '<promise: {0.fun!r} {0.args!r} {0.kwargs!r}>'.format(self)
 
class X(object):
    """Bundle of mocks used as the positional arguments of the worker loops.

    ``self.args`` holds, in order: obj, connection, consumer, blueprint,
    hub, qos, heartbeat and clock. Each is also exposed as an attribute of
    the same name. The tests call ``asynloop(*x.args)`` / ``synloop(*x.args)``.
    """

    def __init__(self, app, heartbeat=None, on_task_message=None,
                 transport_driver_type=None):
        hub = Hub()
        self.args = [
            Mock(name='obj'),
            Mock(name='connection'),
            Mock(name='consumer'),
            Mock(name='blueprint'),
            hub,
            Mock(name='qos'),
            heartbeat,
            Mock(name='clock'),
        ]
        (self.obj, self.connection, self.consumer, self.blueprint,
         self.hub, self.qos, self.heartbeat, self.clock) = self.args

        conn = self.connection
        conn.supports_heartbeats = True
        # The lambda reads self.heartbeat when called, not when created.
        conn.get_heartbeat_interval.side_effect = lambda: self.heartbeat
        self.consumer.callbacks = []
        self.obj.strategies = {}
        conn.connection_errors = (socket.error,)
        if transport_driver_type:
            conn.transport.driver_type = transport_driver_type

        # Replace the real hub's timer/poller machinery with mocks.
        hub.readers = {}
        hub.timer = Mock(name='hub.timer')
        hub.timer._queue = [Mock()]
        hub.fire_timers = Mock(name='hub.fire_timers')
        hub.fire_timers.return_value = 1.7
        hub.poller = Mock(name='hub.poller')
        hub.close = Mock(name='hub.close()')  # asynloop calls hub.close
        self.Hub = hub
        self.blueprint.state = RUN

        # A real Consumer is needed for create_task_handler. Its error
        # callbacks are the same mocks exposed on self and self.obj.
        consumer = Consumer(Mock(), timer=Mock(), controller=Mock(), app=app)
        self._consumer = consumer
        consumer.on_task_message = on_task_message or []
        self.obj.create_task_handler = consumer.create_task_handler
        for handler_name in ('on_unknown_message',
                             'on_unknown_task',
                             'on_invalid_task'):
            handler = Mock(name=handler_name)
            setattr(self, handler_name, handler)
            setattr(self.obj, handler_name, handler)
            setattr(consumer, handler_name, handler)
        consumer.strategies = self.obj.strategies

    def timeout_then_error(self, mock):
        """Make *mock* raise socket.timeout once, then socket.error."""
        def on_first_call(*args, **kwargs):
            mock.side_effect = socket.error()
            raise socket.timeout()
        mock.side_effect = on_first_call

    def close_then_error(self, mock=None, mod=0, exc=None):
        """Make *mock* close the blueprint and raise on a chosen call.

        If *mod* is falsy this happens on every call. Otherwise it happens
        once ``call_count`` exceeds *mod*. The exception raised is *exc*,
        or a fresh ``socket.error`` when *exc* is None. Returns *mock*
        (a new Mock if none was given).
        """
        mock = Mock() if mock is None else mock

        def maybe_close_and_raise(*args, **kwargs):
            if mod and mock.call_count <= mod:
                return
            self.close()
            raise socket.error() if exc is None else exc
        mock.side_effect = maybe_close_and_raise
        return mock

    def close(self, *args, **kwargs):
        """Set the blueprint state to CLOSE; extra arguments are ignored."""
        self.blueprint.state = CLOSE

    def closer(self, mock=None, mod=0):
        """Make *mock* close the blueprint (without raising) on a chosen call.

        If *mod* is falsy this happens on every call. Otherwise it happens
        once ``call_count`` reaches *mod*. Returns *mock* (a new Mock if
        none was given).
        """
        if mock is None:
            mock = Mock()

        def maybe_close(*args, **kwargs):
            if mod and mock.call_count < mod:
                return
            self.close()
        mock.side_effect = maybe_close
        return mock
 
def get_task_callback(*args, **kwargs):
    """Build an :class:`X`, run ``asynloop`` once in CLOSE state, return it.

    All arguments are forwarded to :class:`X`. Returns a pair of the
    harness and ``consumer.on_message``, read after ``asynloop`` has run
    (presumably asynloop installs the task callback there; verify in
    celery.worker.loops).
    """
    harness = X(*args, **kwargs)
    harness.blueprint.state = CLOSE
    asynloop(*harness.args)
    on_message = harness.consumer.on_message
    return harness, on_message
 
class test_asynloop:
    """Tests for the event-loop based ``asynloop`` worker loop."""

    def setup_method(self):
        # pytest's xunit ``setup_method`` hook. The nose-style ``setup`` name
        # is deprecated since pytest 7.2 and no longer called in pytest 8,
        # which would leave ``self.add`` undefined.
        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        self.add = add

    def test_drain_after_consume(self):
        """With the amqp driver, a _quick_drain promise is on hub._ready."""
        x, _ = get_task_callback(self.app, transport_driver_type='amqp')
        assert _quick_drain in [p.fun for p in x.hub._ready]

    def test_pool_did_not_start_at_startup(self):
        x = X(self.app)
        x.obj.restart_count = 0
        x.obj.pool.did_start_ok.return_value = False
        with pytest.raises(WorkerLostError):
            asynloop(*x.args)

    def test_setup_heartbeat(self):
        x = X(self.app, heartbeat=10)
        x.hub.timer.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
        x.blueprint.state = CLOSE
        asynloop(*x.args)
        x.consumer.consume.assert_called_with()
        x.obj.on_ready.assert_called_with()
        x.hub.timer.call_repeatedly.assert_called_with(
            10 / 2.0, x.connection.heartbeat_check, (2.0,),
        )

    def task_context(self, sig, **kwargs):
        """Return ``(x, on_task, message, strategy)`` for signature *sig*.

        *kwargs* are forwarded to :class:`X`. A mock strategy is registered
        under ``sig.task`` so tests can assert how it was called.
        """
        x, on_task = get_task_callback(self.app, **kwargs)
        message = self.task_message_from_sig(self.app, sig)
        strategy = x.obj.strategies[sig.task] = Mock(name='strategy')
        return x, on_task, message, strategy

    def test_on_task_received(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        on_task(msg)
        strategy.assert_called_with(
            msg, None,
            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
            PromiseEqual(x._consumer.call_soon, msg.reject_log_error), [],
        )

    def test_on_task_received_executes_on_task_message(self):
        cbs = [Mock(), Mock(), Mock()]
        x, on_task, msg, strategy = self.task_context(
            self.add.s(2, 2), on_task_message=cbs,
        )
        on_task(msg)
        strategy.assert_called_with(
            msg, None,
            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
            PromiseEqual(x._consumer.call_soon, msg.reject_log_error),
            cbs,
        )

    def test_on_task_message_missing_name(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        msg.headers.pop('task')
        on_task(msg)
        x.on_unknown_message.assert_called_with(msg.decode(), msg)

    def test_on_task_pool_raises(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        strategy.side_effect = ValueError()
        with pytest.raises(ValueError):
            on_task(msg)

    def test_on_task_InvalidTaskError(self):
        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
        exc = strategy.side_effect = InvalidTaskError()
        on_task(msg)
        x.on_invalid_task.assert_called_with(None, msg, exc)

    def test_should_terminate(self):
        x = X(self.app)
        # XXX why aren't the errors propagated?!?
        state.should_terminate = True
        try:
            with pytest.raises(WorkerTerminate):
                asynloop(*x.args)
        finally:
            # Reset module-level state so later tests are unaffected.
            state.should_terminate = None

    def test_should_terminate_hub_close_raises(self):
        x = X(self.app)
        # XXX why aren't the errors propagated?!?
        state.should_terminate = EX_FAILURE
        x.hub.close.side_effect = MemoryError()
        try:
            with pytest.raises(WorkerTerminate):
                asynloop(*x.args)
        finally:
            state.should_terminate = None

    def test_should_stop(self):
        x = X(self.app)
        state.should_stop = 303
        try:
            with pytest.raises(WorkerShutdown):
                asynloop(*x.args)
        finally:
            state.should_stop = None

    def test_updates_qos(self):
        # Unchanged qos (prev == value): update() must not be called.
        x = X(self.app)
        x.qos.prev = 3
        x.qos.value = 3
        x.hub.on_tick.add(x.closer(mod=2))
        x.hub.timer._queue = [1]
        asynloop(*x.args)
        x.qos.update.assert_not_called()

        # Changed qos: update() is called.
        x = X(self.app)
        x.qos.prev = 1
        x.qos.value = 6
        x.hub.on_tick.add(x.closer(mod=2))
        asynloop(*x.args)
        x.qos.update.assert_called_with()
        x.hub.fire_timers.assert_called_with(propagate=(socket.error,))

    def test_poll_empty(self):
        x = X(self.app)
        x.hub.readers = {6: Mock()}
        x.hub.timer._queue = [1]
        x.close_then_error(x.hub.poller.poll)
        x.hub.fire_timers.return_value = 33.37
        poller = x.hub.poller
        poller.poll.return_value = []
        with pytest.raises(socket.error):
            asynloop(*x.args)
        # The value returned by fire_timers is passed to poll().
        poller.poll.assert_called_with(33.37)

    def test_poll_readable(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
        poller = x.hub.poller
        poller.poll.return_value = [(6, READ)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6)
        poller.poll.assert_called()

    def test_poll_readable_raises_Empty(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, READ)]
        reader.side_effect = Empty()
        with pytest.raises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6)
        poller.poll.assert_called()

    def test_poll_writable(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, WRITE)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        writer.assert_called_with(6)
        poller.poll.assert_called()

    def test_poll_writable_none_registered(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        # Event reported for fd 7, which has no registered writer.
        poller.poll.return_value = [(7, WRITE)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        poller.poll.assert_called()

    def test_poll_unknown_event(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        poller = x.hub.poller
        # Event mask 0 matches none of READ/WRITE/ERR.
        poller.poll.return_value = [(6, 0)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        poller.poll.assert_called()

    def test_poll_keep_draining_disabled(self):
        x = X(self.app)
        x.hub.writers = {6: Mock()}
        poll = x.hub.poller.poll

        def se(*args, **kwargs):
            # The first call arms the error for every later call.
            poll.side_effect = socket.error()
        poll.side_effect = se

        poller = x.hub.poller
        poll.return_value = [(6, 0)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        poller.poll.assert_called()

    def test_poll_err_writable(self):
        x = X(self.app)
        writer = Mock(name='writer')
        x.hub.add_writer(6, writer, 6, 48)
        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, ERR)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        writer.assert_called_with(6, 48)
        poller.poll.assert_called()

    def test_poll_write_generator(self):
        x = X(self.app)
        x.hub.remove = Mock(name='hub.remove()')

        def Gen():
            yield 1
            yield 2
        gen = Gen()
        x.hub.add_writer(6, gen)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        # The generator was started but not exhausted, so it is not removed.
        assert gen.gi_frame.f_lasti != -1
        x.hub.remove.assert_not_called()

    def test_poll_write_generator_stopped(self):
        x = X(self.app)

        def Gen():
            # Finish without yielding anything. ``raise StopIteration()``
            # here would be turned into RuntimeError by PEP 479
            # (Python 3.7+), so the generator would fail instead of
            # simply being exhausted.
            return
            yield  # makes this function a generator
        gen = Gen()
        x.hub.add_writer(6, gen)
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        x.hub.remove = Mock(name='hub.remove()')
        with pytest.raises(socket.error):
            asynloop(*x.args)
        assert gen.gi_frame is None

    def test_poll_write_generator_raises(self):
        x = X(self.app)

        def Gen():
            raise ValueError('foo')
            yield
        gen = Gen()
        x.hub.add_writer(6, gen)
        x.hub.remove = Mock(name='hub.remove()')
        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
        x.hub.poller.poll.return_value = [(6, WRITE)]
        with pytest.raises(ValueError):
            asynloop(*x.args)
        assert gen.gi_frame is None
        x.hub.remove.assert_called_with(6)

    def test_poll_err_readable(self):
        x = X(self.app)
        reader = Mock(name='reader')
        x.hub.add_reader(6, reader, 6, 24)
        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
        poller = x.hub.poller
        poller.poll.return_value = [(6, ERR)]
        with pytest.raises(socket.error):
            asynloop(*x.args)
        reader.assert_called_with(6, 24)
        poller.poll.assert_called()

    def test_poll_raises_ValueError(self):
        x = X(self.app)
        x.hub.readers = {6: Mock()}
        poller = x.hub.poller
        x.close_then_error(poller.poll, exc=ValueError)
        # No pytest.raises: the ValueError must not escape asynloop.
        asynloop(*x.args)
        poller.poll.assert_called()
 
class test_synloop:
    """Tests for the blocking ``synloop`` worker loop."""

    def test_timeout_ignored(self):
        x = X(self.app)
        drain = x.connection.drain_events
        x.timeout_then_error(drain)
        with pytest.raises(socket.error):
            synloop(*x.args)
        # Call 1 raised socket.timeout and the loop kept going;
        # call 2 raised socket.error.
        assert drain.call_count == 2

    def test_updates_qos_when_changed(self):
        x = X(self.app)
        qos = x.qos
        qos.prev = qos.value = 2

        def run_until_socket_error():
            x.timeout_then_error(x.connection.drain_events)
            with pytest.raises(socket.error):
                synloop(*x.args)

        run_until_socket_error()
        qos.update.assert_not_called()

        qos.value = 4
        run_until_socket_error()
        qos.update.assert_called_with()

    def test_ignores_socket_errors_when_closed(self):
        x = X(self.app)
        x.close_then_error(x.connection.drain_events)
        result = synloop(*x.args)
        assert result is None
 
class test_quick_drain:
    """Tests for how ``_quick_drain`` forwards the timeout and handles errors."""

    def setup_method(self):
        # pytest's xunit ``setup_method`` hook. The nose-style ``setup`` name
        # is deprecated since pytest 7.2 and no longer called in pytest 8.
        self.connection = Mock(name='connection')

    def test_drain(self):
        _quick_drain(self.connection, timeout=33.3)
        self.connection.drain_events.assert_called_with(timeout=33.3)

    def test_drain_error(self):
        # An exception carrying an arbitrary errno must propagate.
        exc = KeyError()
        exc.errno = 313
        self.connection.drain_events.side_effect = exc
        with pytest.raises(KeyError):
            _quick_drain(self.connection, timeout=33.3)

    def test_drain_error_EAGAIN(self):
        # An exception whose errno is EAGAIN is swallowed.
        exc = KeyError()
        exc.errno = errno.EAGAIN
        self.connection.drain_events.side_effect = exc
        _quick_drain(self.connection, timeout=33.3)
 
 
  |