Bläddra i källkod

More test refactorings

Ask Solem 14 år sedan
förälder
incheckning
c98feeb7a6
1 ändrade filer med 48 tillägg och 120 borttagningar
  1. 48 120
      celery/tests/test_worker/test_worker.py

+ 48 - 120
celery/tests/test_worker/test_worker.py

@@ -7,7 +7,7 @@ from Queue import Empty
 
 from kombu.transport.base import Message
 from kombu.connection import BrokerConnection
-from mock import Mock
+from mock import Mock, patch
 
 from celery import current_app
 from celery.concurrency.base import BasePool
@@ -262,38 +262,28 @@ class test_Consumer(unittest.TestCase):
         context = catch_warnings(record=True)
         execute_context(context, with_catch_warnings)
 
-    def test_receive_message_eta_OverflowError(self):
+    @patch("celery.utils.timer2.to_timestamp")
+    def test_receive_message_eta_OverflowError(self, to_timestamp):
+        to_timestamp.side_effect = OverflowError()
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False)
-        backend = Mock()
-        called = [False]
-
-        def to_timestamp(d):
-            called[0] = True
-            raise OverflowError()
-
-        m = create_message(backend, task=foo_task.name,
+        m = create_message(Mock(), task=foo_task.name,
                                     args=("2, 2"),
                                     kwargs={},
                                     eta=datetime.now().isoformat())
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
-        prev, timer2.to_timestamp = timer2.to_timestamp, to_timestamp
-        try:
-            l.receive_message(m.decode(), m)
-            self.assertTrue(m.acknowledged)
-            self.assertTrue(called[0])
-        finally:
-            timer2.to_timestamp = prev
+        l.receive_message(m.decode(), m)
+        self.assertTrue(m.acknowledged)
+        self.assertTrue(to_timestamp.call_count)
 
     def test_receive_message_InvalidTaskError(self):
         logger = Mock()
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name,
-            args=(1, 2), kwargs="foobarbaz", id=1)
+        m = create_message(Mock(), task=foo_task.name,
+                           args=(1, 2), kwargs="foobarbaz", id=1)
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
@@ -306,26 +296,21 @@ class test_Consumer(unittest.TestCase):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
 
-        class MockMessage(object):
+        class MockMessage(Mock):
             content_type = "application/x-msgpack"
             content_encoding = "binary"
             body = "foobarbaz"
-            acked = False
-
-            def ack(self):
-                self.acked = True
 
         message = MockMessage()
         l.on_decode_error(message, KeyError("foo"))
-        self.assertTrue(message.acked)
+        self.assertTrue(message.ack.call_count)
         self.assertIn("Can't decode message body",
                       logger.critical.call_args[0][0])
 
     def test_receieve_message(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name,
+        m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
 
         l.event_dispatcher = Mock()
@@ -363,51 +348,31 @@ class test_Consumer(unittest.TestCase):
             def drain_events(self, **kwargs):
                 self.obj.connection = None
 
-        class Consumer(object):
-            consuming = False
-            prefetch_count = 0
-
-            def consume(self):
-                self.consuming = True
-
-            def qos(self, prefetch_size=0, prefetch_count=0,
-                            apply_global=False):
-                self.prefetch_count = prefetch_count
-
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False)
         l.connection = Connection()
         l.connection.obj = l
-        l.task_consumer = Consumer()
+        l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, 10, l.logger)
 
         l.consume_messages()
         l.consume_messages()
-        self.assertTrue(l.task_consumer.consuming)
-        self.assertEqual(l.task_consumer.prefetch_count, 10)
-
+        self.assertTrue(l.task_consumer.consume.call_count)
+        l.task_consumer.qos.assert_called_with(prefetch_count=10)
         l.qos.decrement()
         l.consume_messages()
-        self.assertEqual(l.task_consumer.prefetch_count, 9)
+        l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
     def test_maybe_conn_error(self):
-
-        def raises(error):
-
-            def fun():
-                raise error
-
-            return fun
-
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False)
         l.connection_errors = (KeyError, )
         l.channel_errors = (SyntaxError, )
-        l.maybe_conn_error(raises(AttributeError("foo")))
-        l.maybe_conn_error(raises(KeyError("foo")))
-        l.maybe_conn_error(raises(SyntaxError("foo")))
+        l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
+        l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
+        l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
         self.assertRaises(IndexError, l.maybe_conn_error,
-                raises(IndexError("foo")))
+                Mock(side_effect=IndexError("foo")))
 
     def test_apply_eta_task(self):
         from celery.worker import state
@@ -423,21 +388,13 @@ class test_Consumer(unittest.TestCase):
         self.assertIs(self.ready_queue.get_nowait(), task)
 
     def test_receieve_message_eta_isoformat(self):
-
-        class MockConsumer(object):
-            prefetch_count_incremented = False
-
-            def qos(self, **kwargs):
-                self.prefetch_count_incremented = True
-
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False)
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name,
+        m = create_message(Mock(), task=foo_task.name,
                            eta=datetime.now().isoformat(),
                            args=[2, 4, 8], kwargs={})
 
-        l.task_consumer = MockConsumer()
+        l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
         l.event_dispatcher = Mock()
         l.receive_message(m.decode(), m)
@@ -449,7 +406,7 @@ class test_Consumer(unittest.TestCase):
             if item.args[0].task_name == foo_task.name:
                 found = True
         self.assertTrue(found)
-        self.assertTrue(l.task_consumer.prefetch_count_incremented)
+        self.assertTrue(l.task_consumer.qos.call_count)
         l.eta_schedule.stop()
 
     def test_revoke(self):
@@ -519,17 +476,12 @@ class test_Consumer(unittest.TestCase):
 
         class _Consumer(MyKombuConsumer):
             iterations = 0
-            wait_method = None
 
             def reset_connection(self):
                 if self.iterations >= 1:
                     raise KeyError("foo")
 
-        called_back = [False]
-
-        def init_callback(consumer):
-            called_back[0] = True
-
+        init_callback = Mock()
         l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
         l.task_consumer = Mock()
@@ -547,25 +499,21 @@ class test_Consumer(unittest.TestCase):
 
         l.consume_messages = raises_KeyError
         self.assertRaises(KeyError, l.start)
-        self.assertTrue(called_back[0])
+        self.assertTrue(init_callback.call_count)
         self.assertEqual(l.iterations, 1)
         self.assertEqual(l.qos.prev, l.qos.value)
 
+        init_callback.reset_mock()
         l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
         l.qos = _QoS()
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.connection = BrokerConnection()
-
-        def raises_socket_error(limit=None):
-            l.iterations = 1
-            raise socket.error("foo")
-
-        l.consume_messages = raises_socket_error
+        l.consume_messages = Mock(side_effect=socket.error("foo"))
         self.assertRaises(socket.error, l.start)
-        self.assertTrue(called_back[0])
-        self.assertEqual(l.iterations, 1)
+        self.assertTrue(init_callback.call_count)
+        self.assertTrue(l.consume_messages.call_count)
 
 
 class test_WorkController(AppCase):
@@ -578,55 +526,35 @@ class test_WorkController(AppCase):
         worker.logger = Mock()
         return worker
 
-    def test_process_initializer(self):
+    @patch("celery.platforms.reset_signal")
+    @patch("celery.platforms.ignore_signal")
+    @patch("celery.platforms.set_mp_process_title")
+    def test_process_initializer(self, set_mp_process_title, ignore_signal,
+            reset_signal):
         from celery import Celery
-        from celery import platforms
         from celery import signals
         from celery.app import _tls
         from celery.worker import process_initializer
         from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
 
-        ignored_signals = []
-        reset_signals = []
-        worker_init = [False]
-        default_app = current_app
-        app = Celery(loader="default", set_as_current=False)
-
-        class Loader(object):
-
-            def init_worker(self):
-                worker_init[0] = True
-        app.loader = Loader()
-
         def on_worker_process_init(**kwargs):
             on_worker_process_init.called = True
         on_worker_process_init.called = False
         signals.worker_process_init.connect(on_worker_process_init)
 
-        def set_mp_process_title(title, hostname=None):
-            set_mp_process_title.called = (title, hostname)
-        set_mp_process_title.called = ()
-
-        pignore_signal = platforms.ignore_signal
-        preset_signal = platforms.reset_signal
-        psetproctitle = platforms.set_mp_process_title
-        platforms.ignore_signal = lambda sig: ignored_signals.append(sig)
-        platforms.reset_signal = lambda sig: reset_signals.append(sig)
-        platforms.set_mp_process_title = set_mp_process_title
-        try:
-            process_initializer(app, "awesome.worker.com")
-            self.assertItemsEqual(ignored_signals, WORKER_SIGIGNORE)
-            self.assertItemsEqual(reset_signals, WORKER_SIGRESET)
-            self.assertTrue(worker_init[0])
-            self.assertTrue(on_worker_process_init.called)
-            self.assertIs(_tls.current_app, app)
-            self.assertTupleEqual(set_mp_process_title.called,
-                                  ("celeryd", "awesome.worker.com"))
-        finally:
-            platforms.ignore_signal = pignore_signal
-            platforms.reset_signal = preset_signal
-            platforms.set_mp_process_title = psetproctitle
-            default_app.set_current()
+        app = Celery(loader=Mock(), set_as_current=False)
+        process_initializer(app, "awesome.worker.com")
+        for ignoresig in WORKER_SIGIGNORE:
+            self.assertIn(((ignoresig, ), {}),
+                            ignore_signal.call_args_list)
+        for resetsig in WORKER_SIGRESET:
+            self.assertIn(((resetsig, ), {}),
+                            reset_signal.call_args_list)
+        self.assertTrue(app.loader.init_worker.call_count)
+        self.assertTrue(on_worker_process_init.called)
+        self.assertIs(_tls.current_app, app)
+        set_mp_process_title.assert_called_with("celeryd",
+                        hostname="awesome.worker.com")
 
     def test_with_rate_limits_disabled(self):
         worker = WorkController(concurrency=1, loglevel=0,