|
@@ -5,18 +5,21 @@ import socket
|
|
|
from collections import defaultdict
|
|
|
from mock import Mock
|
|
|
|
|
|
+from kombu.async import Hub, READ, WRITE, ERR
|
|
|
+
|
|
|
from celery.exceptions import InvalidTaskError, SystemTerminate
|
|
|
from celery.five import Empty
|
|
|
from celery.worker import state
|
|
|
from celery.worker.consumer import Consumer
|
|
|
-from celery.worker.loops import asynloop, synloop, CLOSE, READ, WRITE, ERR
|
|
|
+from celery.worker.loops import asynloop, synloop, CLOSE
|
|
|
|
|
|
from celery.tests.case import AppCase, body_from_sig
|
|
|
|
|
|
|
|
|
class X(object):
|
|
|
|
|
|
- def __init__(self, app, heartbeat=None, on_task=None):
|
|
|
+ def __init__(self, app, heartbeat=None, on_task_message=None):
|
|
|
+ hub = Hub()
|
|
|
(
|
|
|
self.obj,
|
|
|
self.connection,
|
|
@@ -30,7 +33,7 @@ class X(object):
|
|
|
Mock(name='connection'),
|
|
|
Mock(name='consumer'),
|
|
|
Mock(name='blueprint'),
|
|
|
- Mock(name='Hub'),
|
|
|
+ hub,
|
|
|
Mock(name='qos'),
|
|
|
heartbeat,
|
|
|
Mock(name='clock')]
|
|
@@ -38,16 +41,19 @@ class X(object):
|
|
|
self.consumer.callbacks = []
|
|
|
self.obj.strategies = {}
|
|
|
self.connection.connection_errors = (socket.error, )
|
|
|
- #hent = self.Hub.__enter__ = Mock(name='Hub.__enter__')
|
|
|
- #self.Hub.__exit__ = Mock(name='Hub.__exit__')
|
|
|
- #self.hub = hent.return_value = Mock(name='hub_context')
|
|
|
self.hub.readers = {}
|
|
|
self.hub.writers = {}
|
|
|
self.hub.consolidate = set()
|
|
|
+ self.hub.timer = Mock(name='hub.timer')
|
|
|
+ self.hub.timer._queue = [Mock()]
|
|
|
+ self.hub.fire_timers = Mock(name='hub.fire_timers')
|
|
|
self.hub.fire_timers.return_value = 1.7
|
|
|
+ self.hub.poller = Mock(name='hub.poller')
|
|
|
+ self.hub.close = Mock(name='hub.close()') # asynloop calls hub.close
|
|
|
self.Hub = self.hub
|
|
|
# need this for create_task_handler
|
|
|
_consumer = Consumer(Mock(), timer=Mock(), app=app)
|
|
|
+ _consumer.on_task_message = on_task_message or []
|
|
|
self.obj.create_task_handler = _consumer.create_task_handler
|
|
|
self.on_unknown_message = self.obj.on_unknown_message = Mock(
|
|
|
name='on_unknown_message',
|
|
@@ -71,21 +77,27 @@ class X(object):
|
|
|
raise socket.timeout()
|
|
|
mock.side_effect = first
|
|
|
|
|
|
- def close_then_error(self, mock, mod=0):
|
|
|
+ def close_then_error(self, mock=None, mod=0, exc=None):
|
|
|
+ mock = Mock() if mock is None else mock
|
|
|
|
|
|
def first(*args, **kwargs):
|
|
|
if not mod or mock.call_count > mod:
|
|
|
self.close()
|
|
|
self.connection.more_to_read = False
|
|
|
- raise socket.error()
|
|
|
+ raise (socket.error() if exc is None else exc)
|
|
|
mock.side_effect = first
|
|
|
+ return mock
|
|
|
|
|
|
def close(self, *args, **kwargs):
|
|
|
self.blueprint.state = CLOSE
|
|
|
|
|
|
- def closer(self, mock=None):
|
|
|
+ def closer(self, mock=None, mod=0):
|
|
|
mock = Mock() if mock is None else mock
|
|
|
- mock.side_effect = self.close
|
|
|
+
|
|
|
+ def closing(*args, **kwargs):
|
|
|
+ if not mod or mock.call_count >= mod:
|
|
|
+ self.close()
|
|
|
+ mock.side_effect = closing
|
|
|
return mock
|
|
|
|
|
|
|
|
@@ -107,12 +119,13 @@ class test_asynloop(AppCase):
|
|
|
|
|
|
def test_setup_heartbeat(self):
|
|
|
x = X(self.app, heartbeat=10)
|
|
|
+ x.hub.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
|
|
|
x.blueprint.state = CLOSE
|
|
|
asynloop(*x.args)
|
|
|
x.consumer.consume.assert_called_with()
|
|
|
x.obj.on_ready.assert_called_with()
|
|
|
x.hub.call_repeatedly.assert_called_with(
|
|
|
- 10 / 2.0, x.connection.heartbeat_check, (2.0, ),
|
|
|
+ 10 / 2.0, x.connection.heartbeat_check, 2.0,
|
|
|
)
|
|
|
|
|
|
def task_context(self, sig, **kwargs):
|
|
@@ -127,10 +140,10 @@ class test_asynloop(AppCase):
|
|
|
on_task(body, msg)
|
|
|
strategy.assert_called_with(msg, body, msg.ack_log_error)
|
|
|
|
|
|
- def test_on_task_received_executes_hub_on_task(self):
|
|
|
+ def test_on_task_received_executes_on_task_message(self):
|
|
|
cbs = [Mock(), Mock(), Mock()]
|
|
|
_, on_task, body, msg, _ = self.task_context(
|
|
|
- self.add.s(2, 2), on_task=cbs,
|
|
|
+ self.add.s(2, 2), on_task_message=cbs,
|
|
|
)
|
|
|
on_task(body, msg)
|
|
|
[cb.assert_called_with() for cb in cbs]
|
|
@@ -187,20 +200,24 @@ class test_asynloop(AppCase):
|
|
|
x = X(self.app)
|
|
|
x.qos.prev = 3
|
|
|
x.qos.value = 3
|
|
|
- asynloop(*x.args, sleep=x.closer())
|
|
|
+ x.hub.on_tick.add(x.closer(mod=2))
|
|
|
+ x.hub.timer._queue = [1]
|
|
|
+ asynloop(*x.args)
|
|
|
self.assertFalse(x.qos.update.called)
|
|
|
|
|
|
x = X(self.app)
|
|
|
x.qos.prev = 1
|
|
|
x.qos.value = 6
|
|
|
- asynloop(*x.args, sleep=x.closer())
|
|
|
+ x.hub.on_tick.add(x.closer(mod=2))
|
|
|
+ asynloop(*x.args)
|
|
|
x.qos.update.assert_called_with()
|
|
|
x.hub.fire_timers.assert_called_with(propagate=(socket.error, ))
|
|
|
|
|
|
def test_poll_empty(self):
|
|
|
x = X(self.app)
|
|
|
x.hub.readers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ x.hub.timer._queue = [1]
|
|
|
+ x.close_then_error(x.hub.poller.poll)
|
|
|
x.hub.fire_timers.return_value = 33.37
|
|
|
x.hub.poller.poll.return_value = []
|
|
|
with self.assertRaises(socket.error):
|
|
@@ -209,39 +226,43 @@ class test_asynloop(AppCase):
|
|
|
|
|
|
def test_poll_readable(self):
|
|
|
x = X(self.app)
|
|
|
- x.hub.readers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait, mod=4)
|
|
|
+ reader = Mock(name='reader')
|
|
|
+ x.hub.add_reader(6, reader)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
|
|
|
x.hub.poller.poll.return_value = [(6, READ)]
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
|
- x.hub.readers[6].assert_called_with(6, READ)
|
|
|
+ reader.assert_called_with(6, READ)
|
|
|
self.assertTrue(x.hub.poller.poll.called)
|
|
|
|
|
|
def test_poll_readable_raises_Empty(self):
|
|
|
x = X(self.app)
|
|
|
- x.hub.readers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ reader = Mock(name='reader')
|
|
|
+ x.hub.add_reader(6, reader)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, READ)]
|
|
|
- x.hub.readers[6].side_effect = Empty()
|
|
|
+ reader.side_effect = Empty()
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
|
- x.hub.readers[6].assert_called_with(6, READ)
|
|
|
+ reader.assert_called_with(6, READ)
|
|
|
self.assertTrue(x.hub.poller.poll.called)
|
|
|
|
|
|
def test_poll_writable(self):
|
|
|
x = X(self.app)
|
|
|
- x.hub.writers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ writer = Mock(name='writer')
|
|
|
+ x.hub.add_writer(6, writer)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, WRITE)]
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
|
- x.hub.writers[6].assert_called_with(6, WRITE)
|
|
|
+ writer.assert_called_with(6, WRITE)
|
|
|
self.assertTrue(x.hub.poller.poll.called)
|
|
|
|
|
|
def test_poll_writable_none_registered(self):
|
|
|
x = X(self.app)
|
|
|
- x.hub.writers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ writer = Mock(name='writer')
|
|
|
+ x.hub.add_writer(6, writer)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
|
|
|
x.hub.poller.poll.return_value = [(7, WRITE)]
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
@@ -249,8 +270,9 @@ class test_asynloop(AppCase):
|
|
|
|
|
|
def test_poll_unknown_event(self):
|
|
|
x = X(self.app)
|
|
|
- x.hub.writers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ writer = Mock(name='reader')
|
|
|
+ x.hub.add_writer(6, writer)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, 0)]
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
@@ -275,24 +297,26 @@ class test_asynloop(AppCase):
|
|
|
|
|
|
def test_poll_err_writable(self):
|
|
|
x = X(self.app)
|
|
|
- x.hub.writers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ writer = Mock(name='writer')
|
|
|
+ x.hub.add_writer(6, writer, 48)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, ERR)]
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
|
- x.hub.writers[6].assert_called_with(6, ERR)
|
|
|
+ writer.assert_called_with(6, ERR, 48)
|
|
|
self.assertTrue(x.hub.poller.poll.called)
|
|
|
|
|
|
def test_poll_write_generator(self):
|
|
|
x = X(self.app)
|
|
|
+ x.hub.remove = Mock(name='hub.remove()')
|
|
|
|
|
|
def Gen():
|
|
|
yield 1
|
|
|
yield 2
|
|
|
gen = Gen()
|
|
|
|
|
|
- x.hub.writers = {6: gen}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ x.hub.add_writer(6, gen)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, WRITE)]
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
@@ -306,9 +330,10 @@ class test_asynloop(AppCase):
|
|
|
raise StopIteration()
|
|
|
yield
|
|
|
gen = Gen()
|
|
|
- x.hub.writers = {6: gen}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ x.hub.add_writer(6, gen)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, WRITE)]
|
|
|
+ x.hub.remove = Mock(name='hub.remove()')
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
|
self.assertIsNone(gen.gi_frame)
|
|
@@ -321,8 +346,9 @@ class test_asynloop(AppCase):
|
|
|
raise ValueError('foo')
|
|
|
yield
|
|
|
gen = Gen()
|
|
|
- x.hub.writers = {6: gen}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ x.hub.add_writer(6, gen)
|
|
|
+ x.hub.remove = Mock(name='hub.remove()')
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, WRITE)]
|
|
|
with self.assertRaises(ValueError):
|
|
|
asynloop(*x.args)
|
|
@@ -331,19 +357,19 @@ class test_asynloop(AppCase):
|
|
|
|
|
|
def test_poll_err_readable(self):
|
|
|
x = X(self.app)
|
|
|
- x.hub.readers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
+ reader = Mock(name='reader')
|
|
|
+ x.hub.add_reader(6, reader, 24)
|
|
|
+ x.hub.on_tick.add(x.close_then_error(Mock(), 2))
|
|
|
x.hub.poller.poll.return_value = [(6, ERR)]
|
|
|
with self.assertRaises(socket.error):
|
|
|
asynloop(*x.args)
|
|
|
- x.hub.readers[6].assert_called_with(6, ERR)
|
|
|
+ reader.assert_called_with(6, ERR, 24)
|
|
|
self.assertTrue(x.hub.poller.poll.called)
|
|
|
|
|
|
def test_poll_raises_ValueError(self):
|
|
|
x = X(self.app)
|
|
|
x.hub.readers = {6: Mock()}
|
|
|
- x.close_then_error(x.connection.drain_nowait)
|
|
|
- x.hub.poller.poll.side_effect = ValueError()
|
|
|
+ x.close_then_error(x.hub.poller.poll, exc=ValueError)
|
|
|
asynloop(*x.args)
|
|
|
self.assertTrue(x.hub.poller.poll.called)
|
|
|
|