|  | @@ -3,6 +3,7 @@ from __future__ import absolute_import
 | 
	
		
			
				|  |  |  import errno
 | 
	
		
			
				|  |  |  import socket
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +from amqp import promise
 | 
	
		
			
				|  |  |  from kombu.async import Hub, READ, WRITE, ERR
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from celery.bootsteps import CLOSE, RUN
 | 
	
	
		
			
				|  | @@ -18,6 +19,22 @@ from celery.worker.loops import _quick_drain, asynloop, synloop
 | 
	
		
			
				|  |  |  from celery.tests.case import AppCase, Mock, task_message_from_sig
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class PromiseEqual(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, fun, *args, **kwargs):
 | 
	
		
			
				|  |  | +        self.fun = fun
 | 
	
		
			
				|  |  | +        self.args = args
 | 
	
		
			
				|  |  | +        self.kwargs = kwargs
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __eq__(self, other):
 | 
	
		
			
				|  |  | +        return (other.fun == self.fun and
 | 
	
		
			
				|  |  | +                other.args == self.args and
 | 
	
		
			
				|  |  | +                other.kwargs == self.kwargs)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __repr__(self):
 | 
	
		
			
				|  |  | +        return '<promise: {0.fun!r} {0.args!r} {0.kwargs!r}>'.format(self)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class X(object):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def __init__(self, app, heartbeat=None, on_task_message=None,
 | 
	
	
		
			
				|  | @@ -61,7 +78,8 @@ class X(object):
 | 
	
		
			
				|  |  |          self.Hub = self.hub
 | 
	
		
			
				|  |  |          self.blueprint.state = RUN
 | 
	
		
			
				|  |  |          # need this for create_task_handler
 | 
	
		
			
				|  |  | -        _consumer = Consumer(Mock(), timer=Mock(), controller=Mock(), app=app)
 | 
	
		
			
				|  |  | +        self._consumer = _consumer = Consumer(
 | 
	
		
			
				|  |  | +            Mock(), timer=Mock(), controller=Mock(), app=app)
 | 
	
		
			
				|  |  |          _consumer.on_task_message = on_task_message or []
 | 
	
		
			
				|  |  |          self.obj.create_task_handler = _consumer.create_task_handler
 | 
	
		
			
				|  |  |          self.on_unknown_message = self.obj.on_unknown_message = Mock(
 | 
	
	
		
			
				|  | @@ -157,20 +175,25 @@ class test_asynloop(AppCase):
 | 
	
		
			
				|  |  |          return x, on_task, message, strategy
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_on_task_received(self):
 | 
	
		
			
				|  |  | -        _, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
 | 
	
		
			
				|  |  | +        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
 | 
	
		
			
				|  |  |          on_task(msg)
 | 
	
		
			
				|  |  |          strategy.assert_called_with(
 | 
	
		
			
				|  |  | -            msg, None, msg.ack_log_error, msg.reject_log_error, [],
 | 
	
		
			
				|  |  | +            msg, None,
 | 
	
		
			
				|  |  | +            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
 | 
	
		
			
				|  |  | +            PromiseEqual(x._consumer.call_soon, msg.reject_log_error), [],
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_on_task_received_executes_on_task_message(self):
 | 
	
		
			
				|  |  |          cbs = [Mock(), Mock(), Mock()]
 | 
	
		
			
				|  |  | -        _, on_task, msg, strategy = self.task_context(
 | 
	
		
			
				|  |  | +        x, on_task, msg, strategy = self.task_context(
 | 
	
		
			
				|  |  |              self.add.s(2, 2), on_task_message=cbs,
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |          on_task(msg)
 | 
	
		
			
				|  |  |          strategy.assert_called_with(
 | 
	
		
			
				|  |  | -            msg, None, msg.ack_log_error, msg.reject_log_error, cbs,
 | 
	
		
			
				|  |  | +            msg, None,
 | 
	
		
			
				|  |  | +            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
 | 
	
		
			
				|  |  | +            PromiseEqual(x._consumer.call_soon, msg.reject_log_error),
 | 
	
		
			
				|  |  | +            cbs,
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_on_task_message_missing_name(self):
 |