Explorar o código

Refactored some tests to use mock.Mock

Ask Solem %!s(int64=14) %!d(string=hai) anos
pai
achega
d741c2aeba

+ 65 - 180
celery/tests/test_worker/test_worker.py

@@ -6,7 +6,7 @@ from Queue import Empty
 
 from kombu.transport.base import Message
 from kombu.connection import BrokerConnection
-from celery.utils.timer2 import Timer
+from mock import Mock
 
 from celery import current_app
 from celery.concurrency.base import BasePool
@@ -21,37 +21,20 @@ from celery.worker.job import TaskRequest
 from celery.worker.consumer import Consumer as MainConsumer
 from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX
 from celery.utils.serialization import pickle
+from celery.utils.timer2 import Timer
 
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
 from celery.tests.utils import AppCase, execute_context, skip
 
 
-class MockConsumer(object):
-
-    class Channel(object):
-
-        def close(self):
-            pass
-
-    def register_callback(self, cb):
-        pass
-
-    def consume(self):
-        pass
-
-    @property
-    def channel(self):
-        return self.Channel()
-
-
 class PlaceHolder(object):
         pass
 
 
 class MyKombuConsumer(MainConsumer):
-    broadcast_consumer = MockConsumer()
-    task_consumer = MockConsumer()
+    broadcast_consumer = Mock()
+    task_consumer = Mock()
 
     def __init__(self, *args, **kwargs):
         kwargs.setdefault("pool", BasePool(2))
@@ -101,75 +84,6 @@ def foo_periodic_task():
     return "foo"
 
 
-class MockLogger(object):
-
-    def __init__(self):
-        self.logged = []
-
-    def critical(self, msg, *args, **kwargs):
-        self.logged.append(msg)
-
-    def info(self, msg, *args, **kwargs):
-        self.logged.append(msg)
-
-    def error(self, msg, *args, **kwargs):
-        self.logged.append(msg)
-
-    def debug(self, msg, *args, **kwargs):
-        self.logged.append(msg)
-
-
-class MockBackend(object):
-    _acked = False
-
-    def basic_ack(self, delivery_tag):
-        self._acked = True
-
-
-class MockPool(BasePool):
-    _terminated = False
-    _stopped = False
-
-    def __init__(self, *args, **kwargs):
-        self.raise_regular = kwargs.get("raise_regular", False)
-        self.raise_base = kwargs.get("raise_base", False)
-        self.raise_SystemTerminate = kwargs.get("raise_SystemTerminate",
-                                                False)
-
-    def apply_async(self, *args, **kwargs):
-        if self.raise_regular:
-            raise KeyError("some exception")
-        if self.raise_base:
-            raise KeyboardInterrupt("Ctrl+c")
-        if self.raise_SystemTerminate:
-            raise SystemTerminate()
-
-    def start(self):
-        pass
-
-    def stop(self):
-        self._stopped = True
-        return True
-
-    def terminate(self):
-        self._terminated = True
-        self.stop()
-
-
-class MockController(object):
-
-    def __init__(self, w, *args, **kwargs):
-        self._w = w
-        self._stopped = False
-
-    def start(self):
-        self._w["started"] = True
-        self._stopped = False
-
-    def stop(self):
-        self._stopped = True
-
-
 def create_message(backend, **data):
     data.setdefault("id", gen_unique_id())
     return Message(backend, body=pickle.dumps(dict(**data)),
@@ -231,15 +145,8 @@ class test_QoS(unittest.TestCase):
         threaded([add, sub])  # n = 2
         self.assertEqual(qos.value, 1000)
 
-    class MockConsumer(object):
-        prefetch_count = 0
-
-        def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
-            self.prefetch_count = prefetch_count
-
     def test_exceeds_short(self):
-        consumer = self.MockConsumer()
-        qos = QoS(consumer, PREFETCH_COUNT_MAX - 1,
+        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1,
                 current_app.log.get_default_logger())
         qos.update()
         self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
@@ -253,17 +160,17 @@ class test_QoS(unittest.TestCase):
         self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
 
     def test_consumer_increment_decrement(self):
-        consumer = self.MockConsumer()
+        consumer = Mock()
         qos = QoS(consumer, 10, current_app.log.get_default_logger())
         qos.update()
         self.assertEqual(qos.value, 10)
-        self.assertEqual(consumer.prefetch_count, 10)
+        self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
         qos.decrement()
         self.assertEqual(qos.value, 9)
-        self.assertEqual(consumer.prefetch_count, 9)
+        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
         qos.decrement_eventually()
         self.assertEqual(qos.value, 8)
-        self.assertEqual(consumer.prefetch_count, 9)
+        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
 
         # Does not decrement 0 value
         qos.value = 0
@@ -340,7 +247,7 @@ class test_Consumer(unittest.TestCase):
     def test_receive_message_unknown(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
-        backend = MockBackend()
+        backend = Mock()
         m = create_message(backend, unknown={"baz": "!!!"})
         l.event_dispatcher = MockEventDispatcher()
         l.pidbox_node = MockNode()
@@ -356,7 +263,7 @@ class test_Consumer(unittest.TestCase):
     def test_receive_message_eta_OverflowError(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False)
-        backend = MockBackend()
+        backend = Mock()
         called = [False]
 
         def to_timestamp(d):
@@ -379,20 +286,20 @@ class test_Consumer(unittest.TestCase):
             timer2.to_timestamp = prev
 
     def test_receive_message_InvalidTaskError(self):
-        logger = MockLogger()
+        logger = Mock()
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
-        backend = MockBackend()
+        backend = Mock()
         m = create_message(backend, task=foo_task.name,
             args=(1, 2), kwargs="foobarbaz", id=1)
         l.event_dispatcher = MockEventDispatcher()
         l.pidbox_node = MockNode()
 
         l.receive_message(m.decode(), m)
-        self.assertIn("Invalid task ignored", logger.logged[0])
+        self.assertIn("Invalid task ignored", logger.error.call_args[0][0])
 
     def test_on_decode_error(self):
-        logger = MockLogger()
+        logger = Mock()
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
 
@@ -408,12 +315,13 @@ class test_Consumer(unittest.TestCase):
         message = MockMessage()
         l.on_decode_error(message, KeyError("foo"))
         self.assertTrue(message.acked)
-        self.assertIn("Can't decode message body", logger.logged[0])
+        self.assertIn("Can't decode message body",
+                      logger.critical.call_args[0][0])
 
     def test_receieve_message(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
-        backend = MockBackend()
+        backend = Mock()
         m = create_message(backend, task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
 
@@ -520,7 +428,7 @@ class test_Consumer(unittest.TestCase):
 
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False)
-        backend = MockBackend()
+        backend = Mock()
         m = create_message(backend, task=foo_task.name,
                            eta=datetime.now().isoformat(),
                            args=[2, 4, 8], kwargs={})
@@ -544,7 +452,7 @@ class test_Consumer(unittest.TestCase):
         ready_queue = FastQueue()
         l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
-        backend = MockBackend()
+        backend = Mock()
         id = gen_unique_id()
         t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                            kwargs={}, id=id)
@@ -557,7 +465,7 @@ class test_Consumer(unittest.TestCase):
     def test_receieve_message_not_registered(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
-        backend = MockBackend()
+        backend = Mock()
         m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
 
         l.event_dispatcher = MockEventDispatcher()
@@ -569,7 +477,7 @@ class test_Consumer(unittest.TestCase):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
         l.event_dispatcher = MockEventDispatcher()
-        backend = MockBackend()
+        backend = Mock()
         m = create_message(backend, task=foo_task.name,
                            args=[2, 4, 8], kwargs={},
                            eta=(datetime.now() +
@@ -619,8 +527,8 @@ class test_Consumer(unittest.TestCase):
 
         l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
-        l.task_consumer = MockConsumer()
-        l.broadcast_consumer = MockConsumer()
+        l.task_consumer = Mock()
+        l.broadcast_consumer = Mock()
         l.qos = _QoS()
         l.connection = BrokerConnection()
         l.iterations = 0
@@ -641,8 +549,8 @@ class test_Consumer(unittest.TestCase):
         l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
         l.qos = _QoS()
-        l.task_consumer = MockConsumer()
-        l.broadcast_consumer = MockConsumer()
+        l.task_consumer = Mock()
+        l.broadcast_consumer = Mock()
         l.connection = BrokerConnection()
 
         def raises_socket_error(limit=None):
@@ -662,7 +570,7 @@ class test_WorkController(AppCase):
 
     def create_worker(self, **kw):
         worker = WorkController(concurrency=1, loglevel=0, **kw)
-        worker.logger = MockLogger()
+        worker.logger = Mock()
         return worker
 
     def test_process_initializer(self):
@@ -759,7 +667,7 @@ class test_WorkController(AppCase):
 
     def test_on_timer_error(self):
         worker = WorkController(concurrency=1, loglevel=0)
-        worker.logger = MockLogger()
+        worker.logger = Mock()
 
         try:
             raise KeyError("foo")
@@ -767,32 +675,34 @@ class test_WorkController(AppCase):
             exc_info = sys.exc_info()
 
         worker.on_timer_error(exc_info)
-        logged = worker.logger.logged[0]
+        logged = worker.logger.error.call_args[0][0]
         self.assertIn("KeyError", logged)
 
     def test_on_timer_tick(self):
         worker = WorkController(concurrency=1, loglevel=10)
-        worker.logger = MockLogger()
+        worker.logger = Mock()
         worker.timer_debug = worker.logger.debug
 
         worker.on_timer_tick(30.0)
-        logged = worker.logger.logged[0]
+        logged = worker.logger.debug.call_args[0][0]
         self.assertIn("30.0", logged)
 
     def test_process_task(self):
         worker = self.worker
-        worker.pool = MockPool()
-        backend = MockBackend()
+        worker.pool = Mock()
+        backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = TaskRequest.from_message(m, m.decode())
         worker.process_task(task)
+        self.assertEqual(worker.pool.apply_async.call_count, 1)
         worker.pool.stop()
 
     def test_process_task_raise_base(self):
         worker = self.worker
-        worker.pool = MockPool(raise_base=True)
-        backend = MockBackend()
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
+        backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = TaskRequest.from_message(m, m.decode())
@@ -803,8 +713,9 @@ class test_WorkController(AppCase):
 
     def test_process_task_raise_SystemTerminate(self):
         worker = self.worker
-        worker.pool = MockPool(raise_SystemTerminate=True)
-        backend = MockBackend()
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = SystemTerminate()
+        backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = TaskRequest.from_message(m, m.decode())
@@ -815,8 +726,9 @@ class test_WorkController(AppCase):
 
     def test_process_task_raise_regular(self):
         worker = self.worker
-        worker.pool = MockPool(raise_regular=True)
-        backend = MockBackend()
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = KeyError("some exception")
+        backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = TaskRequest.from_message(m, m.decode())
@@ -824,43 +736,26 @@ class test_WorkController(AppCase):
         worker.pool.stop()
 
     def test_start_catches_base_exceptions(self):
-
-        class Component(object):
-            stopped = False
-            terminated = False
-
-            def __init__(self, exc):
-                self.exc = exc
-
-            def start(self):
-                raise self.exc
-
-            def terminate(self):
-                self.terminated = True
-
-            def stop(self):
-                self.stopped = True
-
         worker1 = self.create_worker()
-        worker1.components = [Component(SystemTerminate())]
+        stc = Mock()
+        stc.start.side_effect = SystemTerminate()
+        worker1.components = [stc]
         self.assertRaises(SystemExit, worker1.start)
-        self.assertTrue(worker1.components[0].terminated)
+        self.assertTrue(stc.terminate.call_count)
 
         worker2 = self.create_worker()
-        worker2.components = [Component(SystemExit())]
+        sec = Mock()
+        sec.start.side_effect = SystemExit()
+        sec.terminate = None
+        worker2.components = [sec]
         self.assertRaises(SystemExit, worker2.start)
-        self.assertTrue(worker2.components[0].stopped)
+        self.assertTrue(sec.stop.call_count)
 
     def test_state_db(self):
         from celery.worker import state
         Persistent = state.Persistent
 
-        class MockPersistent(Persistent):
-
-            def _load(self):
-                return {}
-
-        state.Persistent = MockPersistent
+        state.Persistent = Mock()
         try:
             worker = self.create_worker(db="statefilename")
             self.assertTrue(worker._finalize_db)
@@ -878,37 +773,27 @@ class test_WorkController(AppCase):
 
     def test_start__stop(self):
         worker = self.worker
-        w1 = {"started": False}
-        w2 = {"started": False}
-        w3 = {"started": False}
-        w4 = {"started": False}
-        worker.components = [MockController(w1), MockController(w2),
-                             MockController(w3), MockController(w4)]
+        worker.components = [Mock(), Mock(), Mock(), Mock()]
 
         worker.start()
-        for w in (w1, w2, w3, w4):
-            self.assertTrue(w["started"])
-        self.assertTrue(worker._running, len(worker.components))
+        for w in worker.components:
+            self.assertTrue(w.start.call_count)
         worker.stop()
         for component in worker.components:
-            self.assertTrue(component._stopped)
+            self.assertTrue(w.stop.call_count)
 
     def test_start__terminate(self):
         worker = self.worker
-        w1 = {"started": False}
-        w2 = {"started": False}
-        w3 = {"started": False}
-        w4 = {"started": False}
-        worker.components = [MockController(w1), MockController(w2),
-                             MockController(w3), MockController(w4),
-                             MockPool()]
+        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
+        for component in worker.components[:3]:
+            component.terminate = None
 
         worker.start()
-        for w in (w1, w2, w3, w4):
-            self.assertTrue(w["started"])
+        for w in worker.components[:3]:
+            self.assertTrue(w.start.call_count)
         self.assertTrue(worker._running, len(worker.components))
         self.assertEqual(worker._state, RUN)
         worker.terminate()
-        for component in worker.components:
-            self.assertTrue(component._stopped)
-        self.assertTrue(worker.components[4]._terminated)
+        for component in worker.components[:3]:
+            self.assertTrue(component.stop.call_count)
+        self.assertTrue(worker.components[4].terminate.call_count)

+ 10 - 22
celery/tests/test_worker/test_worker_control.py

@@ -4,6 +4,7 @@ from celery.tests.utils import unittest
 from datetime import datetime, timedelta
 
 from kombu import pidbox
+from mock import Mock
 
 from celery.utils.timer2 import Timer
 
@@ -26,22 +27,6 @@ def mytask():
     pass
 
 
-class Dispatcher(object):
-    enabled = None
-
-    def __init__(self, *args, **kwargs):
-        self.sent = []
-
-    def enable(self):
-        self.enabled = True
-
-    def disable(self):
-        self.enabled = False
-
-    def send(self, event, **fields):
-        self.sent.append(event)
-
-
 class Consumer(object):
 
     def __init__(self):
@@ -52,7 +37,7 @@ class Consumer(object):
                                          kwargs={}))
         self.eta_schedule = Timer()
         self.app = app_or_default()
-        self.event_dispatcher = Dispatcher()
+        self.event_dispatcher = Mock()
 
         from celery.concurrency.base import BasePool
         self.pool = BasePool(10)
@@ -83,8 +68,9 @@ class test_ControlPanel(unittest.TestCase):
         panel = self.create_panel(consumer=consumer)
         consumer.event_dispatcher.enabled = False
         panel.handle("enable_events")
-        self.assertEqual(consumer.event_dispatcher.enabled, True)
-        self.assertIn("worker-online", consumer.event_dispatcher.sent)
+        self.assertTrue(consumer.event_dispatcher.enable.call_count)
+        self.assertIn(("worker-online", ),
+                consumer.event_dispatcher.send.call_args)
         self.assertTrue(panel.handle("enable_events")["ok"])
 
     def test_disable_events(self):
@@ -92,8 +78,9 @@ class test_ControlPanel(unittest.TestCase):
         panel = self.create_panel(consumer=consumer)
         consumer.event_dispatcher.enabled = True
         panel.handle("disable_events")
-        self.assertEqual(consumer.event_dispatcher.enabled, False)
-        self.assertIn("worker-offline", consumer.event_dispatcher.sent)
+        self.assertTrue(consumer.event_dispatcher.disable.call_count)
+        self.assertIn(("worker-offline", ),
+                      consumer.event_dispatcher.send.call_args)
         self.assertTrue(panel.handle("disable_events")["ok"])
 
     def test_heartbeat(self):
@@ -101,7 +88,8 @@ class test_ControlPanel(unittest.TestCase):
         panel = self.create_panel(consumer=consumer)
         consumer.event_dispatcher.enabled = True
         panel.handle("heartbeat")
-        self.assertIn("worker-heartbeat", consumer.event_dispatcher.sent)
+        self.assertIn(("worker-heartbeat", ),
+                      consumer.event_dispatcher.send.call_args)
 
     def test_dump_tasks(self):
         info = "\n".join(self.panel.handle("dump_tasks"))

+ 4 - 4
celery/tests/test_worker/test_worker_mediator.py

@@ -2,6 +2,8 @@ from celery.tests.utils import unittest
 
 from Queue import Queue
 
+from mock import Mock
+
 from celery.utils import gen_unique_id
 from celery.worker.mediator import Mediator
 from celery.worker.state import revoked as revoked_tasks
@@ -11,13 +13,11 @@ class MockTask(object):
     hostname = "harness.com"
     task_id = 1234
     task_name = "mocktask"
-    acked = False
 
     def __init__(self, value, **kwargs):
         self.value = value
 
-    def on_ack(self):
-        self.acked = True
+    on_ack = Mock()
 
     def revoked(self):
         if self.task_id in revoked_tasks:
@@ -96,4 +96,4 @@ class test_Mediator(unittest.TestCase):
         m.move()
 
         self.assertNotIn("value", got)
-        self.assertTrue(t.acked)
+        self.assertTrue(t.on_ack.call_count)

+ 1 - 1
celery/worker/__init__.py

@@ -297,7 +297,7 @@ class WorkController(object):
                     what, component.__class__.__name__))
             stop = component.stop
             if not warm:
-                stop = getattr(component, "terminate", stop)
+                stop = getattr(component, "terminate", None) or stop
             stop()
 
         self.priority_timer.stop()