|
@@ -3,6 +3,7 @@ from __future__ import absolute_import
|
|
|
import errno
|
|
|
import socket
|
|
|
|
|
|
+from amqp import promise
|
|
|
from kombu.async import Hub, READ, WRITE, ERR
|
|
|
|
|
|
from celery.bootsteps import CLOSE, RUN
|
|
@@ -18,6 +19,22 @@ from celery.worker.loops import _quick_drain, asynloop, synloop
|
|
|
from celery.tests.case import AppCase, Mock, task_message_from_sig
|
|
|
|
|
|
|
|
|
+class PromiseEqual(object):
|
|
|
+
|
|
|
+ def __init__(self, fun, *args, **kwargs):
|
|
|
+ self.fun = fun
|
|
|
+ self.args = args
|
|
|
+ self.kwargs = kwargs
|
|
|
+
|
|
|
+ def __eq__(self, other):
|
|
|
+ return (other.fun == self.fun and
|
|
|
+ other.args == self.args and
|
|
|
+ other.kwargs == self.kwargs)
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return '<promise: {0.fun!r} {0.args!r} {0.kwargs!r}>'.format(self)
|
|
|
+
|
|
|
+
|
|
|
class X(object):
|
|
|
|
|
|
def __init__(self, app, heartbeat=None, on_task_message=None,
|
|
@@ -61,7 +78,8 @@ class X(object):
|
|
|
self.Hub = self.hub
|
|
|
self.blueprint.state = RUN
|
|
|
# need this for create_task_handler
|
|
|
- _consumer = Consumer(Mock(), timer=Mock(), controller=Mock(), app=app)
|
|
|
+ self._consumer = _consumer = Consumer(
|
|
|
+ Mock(), timer=Mock(), controller=Mock(), app=app)
|
|
|
_consumer.on_task_message = on_task_message or []
|
|
|
self.obj.create_task_handler = _consumer.create_task_handler
|
|
|
self.on_unknown_message = self.obj.on_unknown_message = Mock(
|
|
@@ -157,20 +175,25 @@ class test_asynloop(AppCase):
|
|
|
return x, on_task, message, strategy
|
|
|
|
|
|
def test_on_task_received(self):
|
|
|
- _, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
|
|
|
+ x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
|
|
|
on_task(msg)
|
|
|
strategy.assert_called_with(
|
|
|
- msg, None, msg.ack_log_error, msg.reject_log_error, [],
|
|
|
+ msg, None,
|
|
|
+ PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
|
|
|
+ PromiseEqual(x._consumer.call_soon, msg.reject_log_error), [],
|
|
|
)
|
|
|
|
|
|
def test_on_task_received_executes_on_task_message(self):
|
|
|
cbs = [Mock(), Mock(), Mock()]
|
|
|
- _, on_task, msg, strategy = self.task_context(
|
|
|
+ x, on_task, msg, strategy = self.task_context(
|
|
|
self.add.s(2, 2), on_task_message=cbs,
|
|
|
)
|
|
|
on_task(msg)
|
|
|
strategy.assert_called_with(
|
|
|
- msg, None, msg.ack_log_error, msg.reject_log_error, cbs,
|
|
|
+ msg, None,
|
|
|
+ PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
|
|
|
+ PromiseEqual(x._consumer.call_soon, msg.reject_log_error),
|
|
|
+ cbs,
|
|
|
)
|
|
|
|
|
|
def test_on_task_message_missing_name(self):
|