|
@@ -7,7 +7,7 @@ from Queue import Empty
|
|
|
|
|
|
from kombu.transport.base import Message
|
|
|
from kombu.connection import BrokerConnection
|
|
|
-from mock import Mock
|
|
|
+from mock import Mock, patch
|
|
|
|
|
|
from celery import current_app
|
|
|
from celery.concurrency.base import BasePool
|
|
@@ -262,38 +262,28 @@ class test_Consumer(unittest.TestCase):
|
|
|
context = catch_warnings(record=True)
|
|
|
execute_context(context, with_catch_warnings)
|
|
|
|
|
|
- def test_receive_message_eta_OverflowError(self):
|
|
|
+ @patch("celery.utils.timer2.to_timestamp")
|
|
|
+ def test_receive_message_eta_OverflowError(self, to_timestamp):
|
|
|
+ to_timestamp.side_effect = OverflowError()
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = Mock()
|
|
|
- called = [False]
|
|
|
-
|
|
|
- def to_timestamp(d):
|
|
|
- called[0] = True
|
|
|
- raise OverflowError()
|
|
|
-
|
|
|
- m = create_message(backend, task=foo_task.name,
|
|
|
+ m = create_message(Mock(), task=foo_task.name,
|
|
|
args=("2, 2"),
|
|
|
kwargs={},
|
|
|
eta=datetime.now().isoformat())
|
|
|
l.event_dispatcher = Mock()
|
|
|
l.pidbox_node = MockNode()
|
|
|
|
|
|
- prev, timer2.to_timestamp = timer2.to_timestamp, to_timestamp
|
|
|
- try:
|
|
|
- l.receive_message(m.decode(), m)
|
|
|
- self.assertTrue(m.acknowledged)
|
|
|
- self.assertTrue(called[0])
|
|
|
- finally:
|
|
|
- timer2.to_timestamp = prev
|
|
|
+ l.receive_message(m.decode(), m)
|
|
|
+ self.assertTrue(m.acknowledged)
|
|
|
+ self.assertTrue(to_timestamp.call_count)
|
|
|
|
|
|
def test_receive_message_InvalidTaskError(self):
|
|
|
logger = Mock()
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
|
|
|
send_events=False)
|
|
|
- backend = Mock()
|
|
|
- m = create_message(backend, task=foo_task.name,
|
|
|
- args=(1, 2), kwargs="foobarbaz", id=1)
|
|
|
+ m = create_message(Mock(), task=foo_task.name,
|
|
|
+ args=(1, 2), kwargs="foobarbaz", id=1)
|
|
|
l.event_dispatcher = Mock()
|
|
|
l.pidbox_node = MockNode()
|
|
|
|
|
@@ -306,26 +296,21 @@ class test_Consumer(unittest.TestCase):
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
|
|
|
send_events=False)
|
|
|
|
|
|
- class MockMessage(object):
|
|
|
+ class MockMessage(Mock):
|
|
|
content_type = "application/x-msgpack"
|
|
|
content_encoding = "binary"
|
|
|
body = "foobarbaz"
|
|
|
- acked = False
|
|
|
-
|
|
|
- def ack(self):
|
|
|
- self.acked = True
|
|
|
|
|
|
message = MockMessage()
|
|
|
l.on_decode_error(message, KeyError("foo"))
|
|
|
- self.assertTrue(message.acked)
|
|
|
+ self.assertTrue(message.ack.call_count)
|
|
|
self.assertIn("Can't decode message body",
|
|
|
logger.critical.call_args[0][0])
|
|
|
|
|
|
def test_receieve_message(self):
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = Mock()
|
|
|
- m = create_message(backend, task=foo_task.name,
|
|
|
+ m = create_message(Mock(), task=foo_task.name,
|
|
|
args=[2, 4, 8], kwargs={})
|
|
|
|
|
|
l.event_dispatcher = Mock()
|
|
@@ -363,51 +348,31 @@ class test_Consumer(unittest.TestCase):
|
|
|
def drain_events(self, **kwargs):
|
|
|
self.obj.connection = None
|
|
|
|
|
|
- class Consumer(object):
|
|
|
- consuming = False
|
|
|
- prefetch_count = 0
|
|
|
-
|
|
|
- def consume(self):
|
|
|
- self.consuming = True
|
|
|
-
|
|
|
- def qos(self, prefetch_size=0, prefetch_count=0,
|
|
|
- apply_global=False):
|
|
|
- self.prefetch_count = prefetch_count
|
|
|
-
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
l.connection = Connection()
|
|
|
l.connection.obj = l
|
|
|
- l.task_consumer = Consumer()
|
|
|
+ l.task_consumer = Mock()
|
|
|
l.qos = QoS(l.task_consumer, 10, l.logger)
|
|
|
|
|
|
l.consume_messages()
|
|
|
l.consume_messages()
|
|
|
- self.assertTrue(l.task_consumer.consuming)
|
|
|
- self.assertEqual(l.task_consumer.prefetch_count, 10)
|
|
|
-
|
|
|
+ self.assertTrue(l.task_consumer.consume.call_count)
|
|
|
+ l.task_consumer.qos.assert_called_with(prefetch_count=10)
|
|
|
l.qos.decrement()
|
|
|
l.consume_messages()
|
|
|
- self.assertEqual(l.task_consumer.prefetch_count, 9)
|
|
|
+ l.task_consumer.qos.assert_called_with(prefetch_count=9)
|
|
|
|
|
|
def test_maybe_conn_error(self):
|
|
|
-
|
|
|
- def raises(error):
|
|
|
-
|
|
|
- def fun():
|
|
|
- raise error
|
|
|
-
|
|
|
- return fun
|
|
|
-
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
l.connection_errors = (KeyError, )
|
|
|
l.channel_errors = (SyntaxError, )
|
|
|
- l.maybe_conn_error(raises(AttributeError("foo")))
|
|
|
- l.maybe_conn_error(raises(KeyError("foo")))
|
|
|
- l.maybe_conn_error(raises(SyntaxError("foo")))
|
|
|
+ l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
|
|
|
+ l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
|
|
|
+ l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
|
|
|
self.assertRaises(IndexError, l.maybe_conn_error,
|
|
|
- raises(IndexError("foo")))
|
|
|
+ Mock(side_effect=IndexError("foo")))
|
|
|
|
|
|
def test_apply_eta_task(self):
|
|
|
from celery.worker import state
|
|
@@ -423,21 +388,13 @@ class test_Consumer(unittest.TestCase):
|
|
|
self.assertIs(self.ready_queue.get_nowait(), task)
|
|
|
|
|
|
def test_receieve_message_eta_isoformat(self):
|
|
|
-
|
|
|
- class MockConsumer(object):
|
|
|
- prefetch_count_incremented = False
|
|
|
-
|
|
|
- def qos(self, **kwargs):
|
|
|
- self.prefetch_count_incremented = True
|
|
|
-
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = Mock()
|
|
|
- m = create_message(backend, task=foo_task.name,
|
|
|
+ m = create_message(Mock(), task=foo_task.name,
|
|
|
eta=datetime.now().isoformat(),
|
|
|
args=[2, 4, 8], kwargs={})
|
|
|
|
|
|
- l.task_consumer = MockConsumer()
|
|
|
+ l.task_consumer = Mock()
|
|
|
l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
|
|
|
l.event_dispatcher = Mock()
|
|
|
l.receive_message(m.decode(), m)
|
|
@@ -449,7 +406,7 @@ class test_Consumer(unittest.TestCase):
|
|
|
if item.args[0].task_name == foo_task.name:
|
|
|
found = True
|
|
|
self.assertTrue(found)
|
|
|
- self.assertTrue(l.task_consumer.prefetch_count_incremented)
|
|
|
+ self.assertTrue(l.task_consumer.qos.call_count)
|
|
|
l.eta_schedule.stop()
|
|
|
|
|
|
def test_revoke(self):
|
|
@@ -519,17 +476,12 @@ class test_Consumer(unittest.TestCase):
|
|
|
|
|
|
class _Consumer(MyKombuConsumer):
|
|
|
iterations = 0
|
|
|
- wait_method = None
|
|
|
|
|
|
def reset_connection(self):
|
|
|
if self.iterations >= 1:
|
|
|
raise KeyError("foo")
|
|
|
|
|
|
- called_back = [False]
|
|
|
-
|
|
|
- def init_callback(consumer):
|
|
|
- called_back[0] = True
|
|
|
-
|
|
|
+ init_callback = Mock()
|
|
|
l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False, init_callback=init_callback)
|
|
|
l.task_consumer = Mock()
|
|
@@ -547,25 +499,21 @@ class test_Consumer(unittest.TestCase):
|
|
|
|
|
|
l.consume_messages = raises_KeyError
|
|
|
self.assertRaises(KeyError, l.start)
|
|
|
- self.assertTrue(called_back[0])
|
|
|
+ self.assertTrue(init_callback.call_count)
|
|
|
self.assertEqual(l.iterations, 1)
|
|
|
self.assertEqual(l.qos.prev, l.qos.value)
|
|
|
|
|
|
+ init_callback.reset_mock()
|
|
|
l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False, init_callback=init_callback)
|
|
|
l.qos = _QoS()
|
|
|
l.task_consumer = Mock()
|
|
|
l.broadcast_consumer = Mock()
|
|
|
l.connection = BrokerConnection()
|
|
|
-
|
|
|
- def raises_socket_error(limit=None):
|
|
|
- l.iterations = 1
|
|
|
- raise socket.error("foo")
|
|
|
-
|
|
|
- l.consume_messages = raises_socket_error
|
|
|
+ l.consume_messages = Mock(side_effect=socket.error("foo"))
|
|
|
self.assertRaises(socket.error, l.start)
|
|
|
- self.assertTrue(called_back[0])
|
|
|
- self.assertEqual(l.iterations, 1)
|
|
|
+ self.assertTrue(init_callback.call_count)
|
|
|
+ self.assertTrue(l.consume_messages.call_count)
|
|
|
|
|
|
|
|
|
class test_WorkController(AppCase):
|
|
@@ -578,55 +526,35 @@ class test_WorkController(AppCase):
|
|
|
worker.logger = Mock()
|
|
|
return worker
|
|
|
|
|
|
- def test_process_initializer(self):
|
|
|
+ @patch("celery.platforms.reset_signal")
|
|
|
+ @patch("celery.platforms.ignore_signal")
|
|
|
+ @patch("celery.platforms.set_mp_process_title")
|
|
|
+ def test_process_initializer(self, set_mp_process_title, ignore_signal,
|
|
|
+ reset_signal):
|
|
|
from celery import Celery
|
|
|
- from celery import platforms
|
|
|
from celery import signals
|
|
|
from celery.app import _tls
|
|
|
from celery.worker import process_initializer
|
|
|
from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
|
|
|
|
|
|
- ignored_signals = []
|
|
|
- reset_signals = []
|
|
|
- worker_init = [False]
|
|
|
- default_app = current_app
|
|
|
- app = Celery(loader="default", set_as_current=False)
|
|
|
-
|
|
|
- class Loader(object):
|
|
|
-
|
|
|
- def init_worker(self):
|
|
|
- worker_init[0] = True
|
|
|
- app.loader = Loader()
|
|
|
-
|
|
|
def on_worker_process_init(**kwargs):
|
|
|
on_worker_process_init.called = True
|
|
|
on_worker_process_init.called = False
|
|
|
signals.worker_process_init.connect(on_worker_process_init)
|
|
|
|
|
|
- def set_mp_process_title(title, hostname=None):
|
|
|
- set_mp_process_title.called = (title, hostname)
|
|
|
- set_mp_process_title.called = ()
|
|
|
-
|
|
|
- pignore_signal = platforms.ignore_signal
|
|
|
- preset_signal = platforms.reset_signal
|
|
|
- psetproctitle = platforms.set_mp_process_title
|
|
|
- platforms.ignore_signal = lambda sig: ignored_signals.append(sig)
|
|
|
- platforms.reset_signal = lambda sig: reset_signals.append(sig)
|
|
|
- platforms.set_mp_process_title = set_mp_process_title
|
|
|
- try:
|
|
|
- process_initializer(app, "awesome.worker.com")
|
|
|
- self.assertItemsEqual(ignored_signals, WORKER_SIGIGNORE)
|
|
|
- self.assertItemsEqual(reset_signals, WORKER_SIGRESET)
|
|
|
- self.assertTrue(worker_init[0])
|
|
|
- self.assertTrue(on_worker_process_init.called)
|
|
|
- self.assertIs(_tls.current_app, app)
|
|
|
- self.assertTupleEqual(set_mp_process_title.called,
|
|
|
- ("celeryd", "awesome.worker.com"))
|
|
|
- finally:
|
|
|
- platforms.ignore_signal = pignore_signal
|
|
|
- platforms.reset_signal = preset_signal
|
|
|
- platforms.set_mp_process_title = psetproctitle
|
|
|
- default_app.set_current()
|
|
|
+ app = Celery(loader=Mock(), set_as_current=False)
|
|
|
+ process_initializer(app, "awesome.worker.com")
|
|
|
+ for ignoresig in WORKER_SIGIGNORE:
|
|
|
+ self.assertIn(((ignoresig, ), {}),
|
|
|
+ ignore_signal.call_args_list)
|
|
|
+ for resetsig in WORKER_SIGRESET:
|
|
|
+ self.assertIn(((resetsig, ), {}),
|
|
|
+ reset_signal.call_args_list)
|
|
|
+ self.assertTrue(app.loader.init_worker.call_count)
|
|
|
+ self.assertTrue(on_worker_process_init.called)
|
|
|
+ self.assertIs(_tls.current_app, app)
|
|
|
+ set_mp_process_title.assert_called_with("celeryd",
|
|
|
+ hostname="awesome.worker.com")
|
|
|
|
|
|
def test_with_rate_limits_disabled(self):
|
|
|
worker = WorkController(concurrency=1, loglevel=0,
|