|
@@ -6,7 +6,7 @@ from Queue import Empty
|
|
|
|
|
|
from kombu.transport.base import Message
|
|
|
from kombu.connection import BrokerConnection
|
|
|
-from celery.utils.timer2 import Timer
|
|
|
+from mock import Mock
|
|
|
|
|
|
from celery import current_app
|
|
|
from celery.concurrency.base import BasePool
|
|
@@ -21,37 +21,20 @@ from celery.worker.job import TaskRequest
|
|
|
from celery.worker.consumer import Consumer as MainConsumer
|
|
|
from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX
|
|
|
from celery.utils.serialization import pickle
|
|
|
+from celery.utils.timer2 import Timer
|
|
|
|
|
|
from celery.tests.compat import catch_warnings
|
|
|
from celery.tests.utils import unittest
|
|
|
from celery.tests.utils import AppCase, execute_context, skip
|
|
|
|
|
|
|
|
|
-class MockConsumer(object):
|
|
|
-
|
|
|
- class Channel(object):
|
|
|
-
|
|
|
- def close(self):
|
|
|
- pass
|
|
|
-
|
|
|
- def register_callback(self, cb):
|
|
|
- pass
|
|
|
-
|
|
|
- def consume(self):
|
|
|
- pass
|
|
|
-
|
|
|
- @property
|
|
|
- def channel(self):
|
|
|
- return self.Channel()
|
|
|
-
|
|
|
-
|
|
|
class PlaceHolder(object):
|
|
|
pass
|
|
|
|
|
|
|
|
|
class MyKombuConsumer(MainConsumer):
|
|
|
- broadcast_consumer = MockConsumer()
|
|
|
- task_consumer = MockConsumer()
|
|
|
+ broadcast_consumer = Mock()
|
|
|
+ task_consumer = Mock()
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
kwargs.setdefault("pool", BasePool(2))
|
|
@@ -101,75 +84,6 @@ def foo_periodic_task():
|
|
|
return "foo"
|
|
|
|
|
|
|
|
|
-class MockLogger(object):
|
|
|
-
|
|
|
- def __init__(self):
|
|
|
- self.logged = []
|
|
|
-
|
|
|
- def critical(self, msg, *args, **kwargs):
|
|
|
- self.logged.append(msg)
|
|
|
-
|
|
|
- def info(self, msg, *args, **kwargs):
|
|
|
- self.logged.append(msg)
|
|
|
-
|
|
|
- def error(self, msg, *args, **kwargs):
|
|
|
- self.logged.append(msg)
|
|
|
-
|
|
|
- def debug(self, msg, *args, **kwargs):
|
|
|
- self.logged.append(msg)
|
|
|
-
|
|
|
-
|
|
|
-class MockBackend(object):
|
|
|
- _acked = False
|
|
|
-
|
|
|
- def basic_ack(self, delivery_tag):
|
|
|
- self._acked = True
|
|
|
-
|
|
|
-
|
|
|
-class MockPool(BasePool):
|
|
|
- _terminated = False
|
|
|
- _stopped = False
|
|
|
-
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
- self.raise_regular = kwargs.get("raise_regular", False)
|
|
|
- self.raise_base = kwargs.get("raise_base", False)
|
|
|
- self.raise_SystemTerminate = kwargs.get("raise_SystemTerminate",
|
|
|
- False)
|
|
|
-
|
|
|
- def apply_async(self, *args, **kwargs):
|
|
|
- if self.raise_regular:
|
|
|
- raise KeyError("some exception")
|
|
|
- if self.raise_base:
|
|
|
- raise KeyboardInterrupt("Ctrl+c")
|
|
|
- if self.raise_SystemTerminate:
|
|
|
- raise SystemTerminate()
|
|
|
-
|
|
|
- def start(self):
|
|
|
- pass
|
|
|
-
|
|
|
- def stop(self):
|
|
|
- self._stopped = True
|
|
|
- return True
|
|
|
-
|
|
|
- def terminate(self):
|
|
|
- self._terminated = True
|
|
|
- self.stop()
|
|
|
-
|
|
|
-
|
|
|
-class MockController(object):
|
|
|
-
|
|
|
- def __init__(self, w, *args, **kwargs):
|
|
|
- self._w = w
|
|
|
- self._stopped = False
|
|
|
-
|
|
|
- def start(self):
|
|
|
- self._w["started"] = True
|
|
|
- self._stopped = False
|
|
|
-
|
|
|
- def stop(self):
|
|
|
- self._stopped = True
|
|
|
-
|
|
|
-
|
|
|
def create_message(backend, **data):
|
|
|
data.setdefault("id", gen_unique_id())
|
|
|
return Message(backend, body=pickle.dumps(dict(**data)),
|
|
@@ -231,15 +145,8 @@ class test_QoS(unittest.TestCase):
|
|
|
threaded([add, sub]) # n = 2
|
|
|
self.assertEqual(qos.value, 1000)
|
|
|
|
|
|
- class MockConsumer(object):
|
|
|
- prefetch_count = 0
|
|
|
-
|
|
|
- def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
|
|
|
- self.prefetch_count = prefetch_count
|
|
|
-
|
|
|
def test_exceeds_short(self):
|
|
|
- consumer = self.MockConsumer()
|
|
|
- qos = QoS(consumer, PREFETCH_COUNT_MAX - 1,
|
|
|
+ qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1,
|
|
|
current_app.log.get_default_logger())
|
|
|
qos.update()
|
|
|
self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
|
|
@@ -253,17 +160,17 @@ class test_QoS(unittest.TestCase):
|
|
|
self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
|
|
|
|
|
|
def test_consumer_increment_decrement(self):
|
|
|
- consumer = self.MockConsumer()
|
|
|
+ consumer = Mock()
|
|
|
qos = QoS(consumer, 10, current_app.log.get_default_logger())
|
|
|
qos.update()
|
|
|
self.assertEqual(qos.value, 10)
|
|
|
- self.assertEqual(consumer.prefetch_count, 10)
|
|
|
+ self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
|
|
|
qos.decrement()
|
|
|
self.assertEqual(qos.value, 9)
|
|
|
- self.assertEqual(consumer.prefetch_count, 9)
|
|
|
+ self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
|
|
|
qos.decrement_eventually()
|
|
|
self.assertEqual(qos.value, 8)
|
|
|
- self.assertEqual(consumer.prefetch_count, 9)
|
|
|
+ self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
|
|
|
|
|
|
# Does not decrement 0 value
|
|
|
qos.value = 0
|
|
@@ -340,7 +247,7 @@ class test_Consumer(unittest.TestCase):
|
|
|
def test_receive_message_unknown(self):
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, unknown={"baz": "!!!"})
|
|
|
l.event_dispatcher = MockEventDispatcher()
|
|
|
l.pidbox_node = MockNode()
|
|
@@ -356,7 +263,7 @@ class test_Consumer(unittest.TestCase):
|
|
|
def test_receive_message_eta_OverflowError(self):
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
called = [False]
|
|
|
|
|
|
def to_timestamp(d):
|
|
@@ -379,20 +286,20 @@ class test_Consumer(unittest.TestCase):
|
|
|
timer2.to_timestamp = prev
|
|
|
|
|
|
def test_receive_message_InvalidTaskError(self):
|
|
|
- logger = MockLogger()
|
|
|
+ logger = Mock()
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
|
|
|
send_events=False)
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name,
|
|
|
args=(1, 2), kwargs="foobarbaz", id=1)
|
|
|
l.event_dispatcher = MockEventDispatcher()
|
|
|
l.pidbox_node = MockNode()
|
|
|
|
|
|
l.receive_message(m.decode(), m)
|
|
|
- self.assertIn("Invalid task ignored", logger.logged[0])
|
|
|
+ self.assertIn("Invalid task ignored", logger.error.call_args[0][0])
|
|
|
|
|
|
def test_on_decode_error(self):
|
|
|
- logger = MockLogger()
|
|
|
+ logger = Mock()
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
|
|
|
send_events=False)
|
|
|
|
|
@@ -408,12 +315,13 @@ class test_Consumer(unittest.TestCase):
|
|
|
message = MockMessage()
|
|
|
l.on_decode_error(message, KeyError("foo"))
|
|
|
self.assertTrue(message.acked)
|
|
|
- self.assertIn("Can't decode message body", logger.logged[0])
|
|
|
+ self.assertIn("Can't decode message body",
|
|
|
+ logger.critical.call_args[0][0])
|
|
|
|
|
|
def test_receieve_message(self):
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name,
|
|
|
args=[2, 4, 8], kwargs={})
|
|
|
|
|
@@ -520,7 +428,7 @@ class test_Consumer(unittest.TestCase):
|
|
|
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name,
|
|
|
eta=datetime.now().isoformat(),
|
|
|
args=[2, 4, 8], kwargs={})
|
|
@@ -544,7 +452,7 @@ class test_Consumer(unittest.TestCase):
|
|
|
ready_queue = FastQueue()
|
|
|
l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
id = gen_unique_id()
|
|
|
t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
|
|
|
kwargs={}, id=id)
|
|
@@ -557,7 +465,7 @@ class test_Consumer(unittest.TestCase):
|
|
|
def test_receieve_message_not_registered(self):
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
|
|
|
|
|
|
l.event_dispatcher = MockEventDispatcher()
|
|
@@ -569,7 +477,7 @@ class test_Consumer(unittest.TestCase):
|
|
|
l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
|
l.event_dispatcher = MockEventDispatcher()
|
|
|
- backend = MockBackend()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name,
|
|
|
args=[2, 4, 8], kwargs={},
|
|
|
eta=(datetime.now() +
|
|
@@ -619,8 +527,8 @@ class test_Consumer(unittest.TestCase):
|
|
|
|
|
|
l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False, init_callback=init_callback)
|
|
|
- l.task_consumer = MockConsumer()
|
|
|
- l.broadcast_consumer = MockConsumer()
|
|
|
+ l.task_consumer = Mock()
|
|
|
+ l.broadcast_consumer = Mock()
|
|
|
l.qos = _QoS()
|
|
|
l.connection = BrokerConnection()
|
|
|
l.iterations = 0
|
|
@@ -641,8 +549,8 @@ class test_Consumer(unittest.TestCase):
|
|
|
l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False, init_callback=init_callback)
|
|
|
l.qos = _QoS()
|
|
|
- l.task_consumer = MockConsumer()
|
|
|
- l.broadcast_consumer = MockConsumer()
|
|
|
+ l.task_consumer = Mock()
|
|
|
+ l.broadcast_consumer = Mock()
|
|
|
l.connection = BrokerConnection()
|
|
|
|
|
|
def raises_socket_error(limit=None):
|
|
@@ -662,7 +570,7 @@ class test_WorkController(AppCase):
|
|
|
|
|
|
def create_worker(self, **kw):
|
|
|
worker = WorkController(concurrency=1, loglevel=0, **kw)
|
|
|
- worker.logger = MockLogger()
|
|
|
+ worker.logger = Mock()
|
|
|
return worker
|
|
|
|
|
|
def test_process_initializer(self):
|
|
@@ -759,7 +667,7 @@ class test_WorkController(AppCase):
|
|
|
|
|
|
def test_on_timer_error(self):
|
|
|
worker = WorkController(concurrency=1, loglevel=0)
|
|
|
- worker.logger = MockLogger()
|
|
|
+ worker.logger = Mock()
|
|
|
|
|
|
try:
|
|
|
raise KeyError("foo")
|
|
@@ -767,32 +675,34 @@ class test_WorkController(AppCase):
|
|
|
exc_info = sys.exc_info()
|
|
|
|
|
|
worker.on_timer_error(exc_info)
|
|
|
- logged = worker.logger.logged[0]
|
|
|
+ logged = worker.logger.error.call_args[0][0]
|
|
|
self.assertIn("KeyError", logged)
|
|
|
|
|
|
def test_on_timer_tick(self):
|
|
|
worker = WorkController(concurrency=1, loglevel=10)
|
|
|
- worker.logger = MockLogger()
|
|
|
+ worker.logger = Mock()
|
|
|
worker.timer_debug = worker.logger.debug
|
|
|
|
|
|
worker.on_timer_tick(30.0)
|
|
|
- logged = worker.logger.logged[0]
|
|
|
+ logged = worker.logger.debug.call_args[0][0]
|
|
|
self.assertIn("30.0", logged)
|
|
|
|
|
|
def test_process_task(self):
|
|
|
worker = self.worker
|
|
|
- worker.pool = MockPool()
|
|
|
- backend = MockBackend()
|
|
|
+ worker.pool = Mock()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
|
|
|
kwargs={})
|
|
|
task = TaskRequest.from_message(m, m.decode())
|
|
|
worker.process_task(task)
|
|
|
+ self.assertEqual(worker.pool.apply_async.call_count, 1)
|
|
|
worker.pool.stop()
|
|
|
|
|
|
def test_process_task_raise_base(self):
|
|
|
worker = self.worker
|
|
|
- worker.pool = MockPool(raise_base=True)
|
|
|
- backend = MockBackend()
|
|
|
+ worker.pool = Mock()
|
|
|
+ worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
|
|
|
kwargs={})
|
|
|
task = TaskRequest.from_message(m, m.decode())
|
|
@@ -803,8 +713,9 @@ class test_WorkController(AppCase):
|
|
|
|
|
|
def test_process_task_raise_SystemTerminate(self):
|
|
|
worker = self.worker
|
|
|
- worker.pool = MockPool(raise_SystemTerminate=True)
|
|
|
- backend = MockBackend()
|
|
|
+ worker.pool = Mock()
|
|
|
+ worker.pool.apply_async.side_effect = SystemTerminate()
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
|
|
|
kwargs={})
|
|
|
task = TaskRequest.from_message(m, m.decode())
|
|
@@ -815,8 +726,9 @@ class test_WorkController(AppCase):
|
|
|
|
|
|
def test_process_task_raise_regular(self):
|
|
|
worker = self.worker
|
|
|
- worker.pool = MockPool(raise_regular=True)
|
|
|
- backend = MockBackend()
|
|
|
+ worker.pool = Mock()
|
|
|
+ worker.pool.apply_async.side_effect = KeyError("some exception")
|
|
|
+ backend = Mock()
|
|
|
m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
|
|
|
kwargs={})
|
|
|
task = TaskRequest.from_message(m, m.decode())
|
|
@@ -824,43 +736,26 @@ class test_WorkController(AppCase):
|
|
|
worker.pool.stop()
|
|
|
|
|
|
def test_start_catches_base_exceptions(self):
|
|
|
-
|
|
|
- class Component(object):
|
|
|
- stopped = False
|
|
|
- terminated = False
|
|
|
-
|
|
|
- def __init__(self, exc):
|
|
|
- self.exc = exc
|
|
|
-
|
|
|
- def start(self):
|
|
|
- raise self.exc
|
|
|
-
|
|
|
- def terminate(self):
|
|
|
- self.terminated = True
|
|
|
-
|
|
|
- def stop(self):
|
|
|
- self.stopped = True
|
|
|
-
|
|
|
worker1 = self.create_worker()
|
|
|
- worker1.components = [Component(SystemTerminate())]
|
|
|
+ stc = Mock()
|
|
|
+ stc.start.side_effect = SystemTerminate()
|
|
|
+ worker1.components = [stc]
|
|
|
self.assertRaises(SystemExit, worker1.start)
|
|
|
- self.assertTrue(worker1.components[0].terminated)
|
|
|
+ self.assertTrue(stc.terminate.call_count)
|
|
|
|
|
|
worker2 = self.create_worker()
|
|
|
- worker2.components = [Component(SystemExit())]
|
|
|
+ sec = Mock()
|
|
|
+ sec.start.side_effect = SystemExit()
|
|
|
+ sec.terminate = None
|
|
|
+ worker2.components = [sec]
|
|
|
self.assertRaises(SystemExit, worker2.start)
|
|
|
- self.assertTrue(worker2.components[0].stopped)
|
|
|
+ self.assertTrue(sec.stop.call_count)
|
|
|
|
|
|
def test_state_db(self):
|
|
|
from celery.worker import state
|
|
|
Persistent = state.Persistent
|
|
|
|
|
|
- class MockPersistent(Persistent):
|
|
|
-
|
|
|
- def _load(self):
|
|
|
- return {}
|
|
|
-
|
|
|
- state.Persistent = MockPersistent
|
|
|
+ state.Persistent = Mock()
|
|
|
try:
|
|
|
worker = self.create_worker(db="statefilename")
|
|
|
self.assertTrue(worker._finalize_db)
|
|
@@ -878,37 +773,27 @@ class test_WorkController(AppCase):
|
|
|
|
|
|
def test_start__stop(self):
|
|
|
worker = self.worker
|
|
|
- w1 = {"started": False}
|
|
|
- w2 = {"started": False}
|
|
|
- w3 = {"started": False}
|
|
|
- w4 = {"started": False}
|
|
|
- worker.components = [MockController(w1), MockController(w2),
|
|
|
- MockController(w3), MockController(w4)]
|
|
|
+ worker.components = [Mock(), Mock(), Mock(), Mock()]
|
|
|
|
|
|
worker.start()
|
|
|
- for w in (w1, w2, w3, w4):
|
|
|
- self.assertTrue(w["started"])
|
|
|
- self.assertTrue(worker._running, len(worker.components))
|
|
|
+ for w in worker.components:
|
|
|
+ self.assertTrue(w.start.call_count)
|
|
|
worker.stop()
|
|
|
for component in worker.components:
|
|
|
- self.assertTrue(component._stopped)
|
|
|
+ self.assertTrue(w.stop.call_count)
|
|
|
|
|
|
def test_start__terminate(self):
|
|
|
worker = self.worker
|
|
|
- w1 = {"started": False}
|
|
|
- w2 = {"started": False}
|
|
|
- w3 = {"started": False}
|
|
|
- w4 = {"started": False}
|
|
|
- worker.components = [MockController(w1), MockController(w2),
|
|
|
- MockController(w3), MockController(w4),
|
|
|
- MockPool()]
|
|
|
+ worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
|
|
|
+ for component in worker.components[:3]:
|
|
|
+ component.terminate = None
|
|
|
|
|
|
worker.start()
|
|
|
- for w in (w1, w2, w3, w4):
|
|
|
- self.assertTrue(w["started"])
|
|
|
+ for w in worker.components[:3]:
|
|
|
+ self.assertTrue(w.start.call_count)
|
|
|
self.assertTrue(worker._running, len(worker.components))
|
|
|
self.assertEqual(worker._state, RUN)
|
|
|
worker.terminate()
|
|
|
- for component in worker.components:
|
|
|
- self.assertTrue(component._stopped)
|
|
|
- self.assertTrue(worker.components[4]._terminated)
|
|
|
+ for component in worker.components[:3]:
|
|
|
+ self.assertTrue(component.stop.call_count)
|
|
|
+ self.assertTrue(worker.components[4].terminate.call_count)
|