|
@@ -16,7 +16,8 @@ from celery.tests.case import AppCase, Mock, body_from_sig
|
|
|
|
|
|
class X(object):
|
|
|
|
|
|
- def __init__(self, app, heartbeat=None, on_task_message=None):
|
|
|
+ def __init__(self, app, heartbeat=None, on_task_message=None,
|
|
|
+ transport_driver_type=None):
|
|
|
hub = Hub()
|
|
|
(
|
|
|
self.obj,
|
|
@@ -42,6 +43,8 @@ class X(object):
|
|
|
self.consumer.callbacks = []
|
|
|
self.obj.strategies = {}
|
|
|
self.connection.connection_errors = (socket.error, )
|
|
|
+ if transport_driver_type:
|
|
|
+ self.connection.transport.driver_type = transport_driver_type
|
|
|
self.hub.readers = {}
|
|
|
self.hub.writers = {}
|
|
|
self.hub.consolidate = set()
|
|
@@ -120,8 +123,10 @@ class test_asynloop(AppCase):
|
|
|
self.add = add
|
|
|
|
|
|
def test_drain_after_consume(self):
|
|
|
- x, _ = get_task_callback(self.app)
|
|
|
- x.connection.drain_events.assert_called_with()
|
|
|
+ x, _ = get_task_callback(self.app, transport_driver_type='amqp')
|
|
|
+ self.assertIn(
|
|
|
+ x.connection.drain_events, [p.fun for p in x.hub._ready],
|
|
|
+ )
|
|
|
|
|
|
def test_setup_heartbeat(self):
|
|
|
x = X(self.app, heartbeat=10)
|