|  | @@ -5,18 +5,21 @@ import socket
 | 
	
		
			
				|  |  |  from collections import defaultdict
 | 
	
		
			
				|  |  |  from mock import Mock
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +from kombu.async import Hub, READ, WRITE, ERR
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  from celery.exceptions import InvalidTaskError, SystemTerminate
 | 
	
		
			
				|  |  |  from celery.five import Empty
 | 
	
		
			
				|  |  |  from celery.worker import state
 | 
	
		
			
				|  |  |  from celery.worker.consumer import Consumer
 | 
	
		
			
				|  |  | -from celery.worker.loops import asynloop, synloop, CLOSE, READ, WRITE, ERR
 | 
	
		
			
				|  |  | +from celery.worker.loops import asynloop, synloop, CLOSE
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from celery.tests.case import AppCase, body_from_sig
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class X(object):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def __init__(self, app, heartbeat=None, on_task=None):
 | 
	
		
			
				|  |  | +    def __init__(self, app, heartbeat=None, on_task_message=None):
 | 
	
		
			
				|  |  | +        hub = Hub()
 | 
	
		
			
				|  |  |          (
 | 
	
		
			
				|  |  |              self.obj,
 | 
	
		
			
				|  |  |              self.connection,
 | 
	
	
		
			
				|  | @@ -30,7 +33,7 @@ class X(object):
 | 
	
		
			
				|  |  |                           Mock(name='connection'),
 | 
	
		
			
				|  |  |                           Mock(name='consumer'),
 | 
	
		
			
				|  |  |                           Mock(name='blueprint'),
 | 
	
		
			
				|  |  | -                         Mock(name='Hub'),
 | 
	
		
			
				|  |  | +                         hub,
 | 
	
		
			
				|  |  |                           Mock(name='qos'),
 | 
	
		
			
				|  |  |                           heartbeat,
 | 
	
		
			
				|  |  |                           Mock(name='clock')]
 | 
	
	
		
			
				|  | @@ -38,16 +41,19 @@ class X(object):
 | 
	
		
			
				|  |  |          self.consumer.callbacks = []
 | 
	
		
			
				|  |  |          self.obj.strategies = {}
 | 
	
		
			
				|  |  |          self.connection.connection_errors = (socket.error, )
 | 
	
		
			
				|  |  | -        #hent = self.Hub.__enter__ = Mock(name='Hub.__enter__')
 | 
	
		
			
				|  |  | -        #self.Hub.__exit__ = Mock(name='Hub.__exit__')
 | 
	
		
			
				|  |  | -        #self.hub = hent.return_value = Mock(name='hub_context')
 | 
	
		
			
				|  |  |          self.hub.readers = {}
 | 
	
		
			
				|  |  |          self.hub.writers = {}
 | 
	
		
			
				|  |  |          self.hub.consolidate = set()
 | 
	
		
			
				|  |  | +        self.hub.timer = Mock(name='hub.timer')
 | 
	
		
			
				|  |  | +        self.hub.timer._queue = [Mock()]
 | 
	
		
			
				|  |  | +        self.hub.fire_timers = Mock(name='hub.fire_timers')
 | 
	
		
			
				|  |  |          self.hub.fire_timers.return_value = 1.7
 | 
	
		
			
				|  |  | +        self.hub.poller = Mock(name='hub.poller')
 | 
	
		
			
				|  |  | +        self.hub.close = Mock(name='hub.close()')  # asynloop calls hub.close
 | 
	
		
			
				|  |  |          self.Hub = self.hub
 | 
	
		
			
				|  |  |          # need this for create_task_handler
 | 
	
		
			
				|  |  |          _consumer = Consumer(Mock(), timer=Mock(), app=app)
 | 
	
		
			
				|  |  | +        _consumer.on_task_message = on_task_message or []
 | 
	
		
			
				|  |  |          self.obj.create_task_handler = _consumer.create_task_handler
 | 
	
		
			
				|  |  |          self.on_unknown_message = self.obj.on_unknown_message = Mock(
 | 
	
		
			
				|  |  |              name='on_unknown_message',
 | 
	
	
		
			
				|  | @@ -71,21 +77,27 @@ class X(object):
 | 
	
		
			
				|  |  |              raise socket.timeout()
 | 
	
		
			
				|  |  |          mock.side_effect = first
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def close_then_error(self, mock, mod=0):
 | 
	
		
			
				|  |  | +    def close_then_error(self, mock=None, mod=0, exc=None):
 | 
	
		
			
				|  |  | +        mock = Mock() if mock is None else mock
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          def first(*args, **kwargs):
 | 
	
		
			
				|  |  |              if not mod or mock.call_count > mod:
 | 
	
		
			
				|  |  |                  self.close()
 | 
	
		
			
				|  |  |                  self.connection.more_to_read = False
 | 
	
		
			
				|  |  | -                raise socket.error()
 | 
	
		
			
				|  |  | +                raise (socket.error() if exc is None else exc)
 | 
	
		
			
				|  |  |          mock.side_effect = first
 | 
	
		
			
				|  |  | +        return mock
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def close(self, *args, **kwargs):
 | 
	
		
			
				|  |  |          self.blueprint.state = CLOSE
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def closer(self, mock=None):
 | 
	
		
			
				|  |  | +    def closer(self, mock=None, mod=0):
 | 
	
		
			
				|  |  |          mock = Mock() if mock is None else mock
 | 
	
		
			
				|  |  | -        mock.side_effect = self.close
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def closing(*args, **kwargs):
 | 
	
		
			
				|  |  | +            if not mod or mock.call_count >= mod:
 | 
	
		
			
				|  |  | +                self.close()
 | 
	
		
			
				|  |  | +        mock.side_effect = closing
 | 
	
		
			
				|  |  |          return mock
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -107,12 +119,13 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_setup_heartbeat(self):
 | 
	
		
			
				|  |  |          x = X(self.app, heartbeat=10)
 | 
	
		
			
				|  |  | +        x.hub.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
 | 
	
		
			
				|  |  |          x.blueprint.state = CLOSE
 | 
	
		
			
				|  |  |          asynloop(*x.args)
 | 
	
		
			
				|  |  |          x.consumer.consume.assert_called_with()
 | 
	
		
			
				|  |  |          x.obj.on_ready.assert_called_with()
 | 
	
		
			
				|  |  |          x.hub.call_repeatedly.assert_called_with(
 | 
	
		
			
				|  |  | -            10 / 2.0, x.connection.heartbeat_check, (2.0, ),
 | 
	
		
			
				|  |  | +            10 / 2.0, x.connection.heartbeat_check, 2.0,
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def task_context(self, sig, **kwargs):
 | 
	
	
		
			
				|  | @@ -127,10 +140,10 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |          on_task(body, msg)
 | 
	
		
			
				|  |  |          strategy.assert_called_with(msg, body, msg.ack_log_error)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def test_on_task_received_executes_hub_on_task(self):
 | 
	
		
			
				|  |  | +    def test_on_task_received_executes_on_task_message(self):
 | 
	
		
			
				|  |  |          cbs = [Mock(), Mock(), Mock()]
 | 
	
		
			
				|  |  |          _, on_task, body, msg, _ = self.task_context(
 | 
	
		
			
				|  |  | -            self.add.s(2, 2), on_task=cbs,
 | 
	
		
			
				|  |  | +            self.add.s(2, 2), on_task_message=cbs,
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |          on_task(body, msg)
 | 
	
		
			
				|  |  |          [cb.assert_called_with() for cb in cbs]
 | 
	
	
		
			
				|  | @@ -187,20 +200,24 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  |          x.qos.prev = 3
 | 
	
		
			
				|  |  |          x.qos.value = 3
 | 
	
		
			
				|  |  | -        asynloop(*x.args, sleep=x.closer())
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.closer(mod=2))
 | 
	
		
			
				|  |  | +        x.hub.timer._queue = [1]
 | 
	
		
			
				|  |  | +        asynloop(*x.args)
 | 
	
		
			
				|  |  |          self.assertFalse(x.qos.update.called)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  |          x.qos.prev = 1
 | 
	
		
			
				|  |  |          x.qos.value = 6
 | 
	
		
			
				|  |  | -        asynloop(*x.args, sleep=x.closer())
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.closer(mod=2))
 | 
	
		
			
				|  |  | +        asynloop(*x.args)
 | 
	
		
			
				|  |  |          x.qos.update.assert_called_with()
 | 
	
		
			
				|  |  |          x.hub.fire_timers.assert_called_with(propagate=(socket.error, ))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_empty(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  |          x.hub.readers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        x.hub.timer._queue = [1]
 | 
	
		
			
				|  |  | +        x.close_then_error(x.hub.poller.poll)
 | 
	
		
			
				|  |  |          x.hub.fire_timers.return_value = 33.37
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = []
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
	
		
			
				|  | @@ -209,39 +226,43 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_readable(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | -        x.hub.readers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait, mod=4)
 | 
	
		
			
				|  |  | +        reader = Mock(name='reader')
 | 
	
		
			
				|  |  | +        x.hub.add_reader(6, reader)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, READ)]
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
		
			
				|  |  | -        x.hub.readers[6].assert_called_with(6, READ)
 | 
	
		
			
				|  |  | +        reader.assert_called_with(6, READ)
 | 
	
		
			
				|  |  |          self.assertTrue(x.hub.poller.poll.called)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_readable_raises_Empty(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | -        x.hub.readers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        reader = Mock(name='reader')
 | 
	
		
			
				|  |  | +        x.hub.add_reader(6, reader)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, READ)]
 | 
	
		
			
				|  |  | -        x.hub.readers[6].side_effect = Empty()
 | 
	
		
			
				|  |  | +        reader.side_effect = Empty()
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
		
			
				|  |  | -        x.hub.readers[6].assert_called_with(6, READ)
 | 
	
		
			
				|  |  | +        reader.assert_called_with(6, READ)
 | 
	
		
			
				|  |  |          self.assertTrue(x.hub.poller.poll.called)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_writable(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | -        x.hub.writers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        writer = Mock(name='writer')
 | 
	
		
			
				|  |  | +        x.hub.add_writer(6, writer)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, WRITE)]
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
		
			
				|  |  | -        x.hub.writers[6].assert_called_with(6, WRITE)
 | 
	
		
			
				|  |  | +        writer.assert_called_with(6, WRITE)
 | 
	
		
			
				|  |  |          self.assertTrue(x.hub.poller.poll.called)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_writable_none_registered(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | -        x.hub.writers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        writer = Mock(name='writer')
 | 
	
		
			
				|  |  | +        x.hub.add_writer(6, writer)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(7, WRITE)]
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
	
		
			
				|  | @@ -249,8 +270,9 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_unknown_event(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | -        x.hub.writers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        writer = Mock(name='reader')
 | 
	
		
			
				|  |  | +        x.hub.add_writer(6, writer)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, 0)]
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
	
		
			
				|  | @@ -275,24 +297,26 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_err_writable(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | -        x.hub.writers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        writer = Mock(name='writer')
 | 
	
		
			
				|  |  | +        x.hub.add_writer(6, writer, 48)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, ERR)]
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
		
			
				|  |  | -        x.hub.writers[6].assert_called_with(6, ERR)
 | 
	
		
			
				|  |  | +        writer.assert_called_with(6, ERR, 48)
 | 
	
		
			
				|  |  |          self.assertTrue(x.hub.poller.poll.called)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_write_generator(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | +        x.hub.remove = Mock(name='hub.remove()')
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          def Gen():
 | 
	
		
			
				|  |  |              yield 1
 | 
	
		
			
				|  |  |              yield 2
 | 
	
		
			
				|  |  |          gen = Gen()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        x.hub.writers = {6: gen}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        x.hub.add_writer(6, gen)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, WRITE)]
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
	
		
			
				|  | @@ -306,9 +330,10 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |              raise StopIteration()
 | 
	
		
			
				|  |  |              yield
 | 
	
		
			
				|  |  |          gen = Gen()
 | 
	
		
			
				|  |  | -        x.hub.writers = {6: gen}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        x.hub.add_writer(6, gen)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, WRITE)]
 | 
	
		
			
				|  |  | +        x.hub.remove = Mock(name='hub.remove()')
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
		
			
				|  |  |          self.assertIsNone(gen.gi_frame)
 | 
	
	
		
			
				|  | @@ -321,8 +346,9 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |              raise ValueError('foo')
 | 
	
		
			
				|  |  |              yield
 | 
	
		
			
				|  |  |          gen = Gen()
 | 
	
		
			
				|  |  | -        x.hub.writers = {6: gen}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        x.hub.add_writer(6, gen)
 | 
	
		
			
				|  |  | +        x.hub.remove = Mock(name='hub.remove()')
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, WRITE)]
 | 
	
		
			
				|  |  |          with self.assertRaises(ValueError):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
	
		
			
				|  | @@ -331,19 +357,19 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_err_readable(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  | -        x.hub.readers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | +        reader = Mock(name='reader')
 | 
	
		
			
				|  |  | +        x.hub.add_reader(6, reader, 24)
 | 
	
		
			
				|  |  | +        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
 | 
	
		
			
				|  |  |          x.hub.poller.poll.return_value = [(6, ERR)]
 | 
	
		
			
				|  |  |          with self.assertRaises(socket.error):
 | 
	
		
			
				|  |  |              asynloop(*x.args)
 | 
	
		
			
				|  |  | -        x.hub.readers[6].assert_called_with(6, ERR)
 | 
	
		
			
				|  |  | +        reader.assert_called_with(6, ERR, 24)
 | 
	
		
			
				|  |  |          self.assertTrue(x.hub.poller.poll.called)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_raises_ValueError(self):
 | 
	
		
			
				|  |  |          x = X(self.app)
 | 
	
		
			
				|  |  |          x.hub.readers = {6: Mock()}
 | 
	
		
			
				|  |  | -        x.close_then_error(x.connection.drain_nowait)
 | 
	
		
			
				|  |  | -        x.hub.poller.poll.side_effect = ValueError()
 | 
	
		
			
				|  |  | +        x.close_then_error(x.hub.poller.poll, exc=ValueError)
 | 
	
		
			
				|  |  |          asynloop(*x.args)
 | 
	
		
			
				|  |  |          self.assertTrue(x.hub.poller.poll.called)
 | 
	
		
			
				|  |  |  
 |