Browse Source

96% coverage for celery.worker.listener

Ask Solem 15 years ago
parent
commit
2df3c25622
1 changed files with 132 additions and 17 deletions
  1. 132 17
      celery/tests/test_worker.py

+ 132 - 17
celery/tests/test_worker.py

@@ -1,24 +1,26 @@
+import socket
 import unittest2 as unittest
-from Queue import Empty
+
 from datetime import datetime, timedelta
 from multiprocessing import get_logger
+from Queue import Empty
 
-from carrot.connection import BrokerConnection
 from carrot.backends.base import BaseMessage
+from carrot.connection import BrokerConnection
 
 from celery import conf
+from celery.decorators import task as task_dec
+from celery.decorators import periodic_task as periodic_task_dec
+from celery.serialization import pickle
 from celery.utils import gen_unique_id
 from celery.worker import WorkController
-from celery.worker.job import TaskRequest
 from celery.worker.buckets import FastQueue
+from celery.worker.job import TaskRequest
 from celery.worker.listener import CarrotListener, QoS, RUN
 from celery.worker.scheduler import Scheduler
-from celery.decorators import task as task_dec
-from celery.decorators import periodic_task as periodic_task_dec
-from celery.serialization import pickle
 
-from celery.tests.utils import execute_context
 from celery.tests.compat import catch_warnings
+from celery.tests.utils import execute_context
 
 
 class PlaceHolder(object):
@@ -62,17 +64,20 @@ def foo_periodic_task():
 
 class MockLogger(object):
 
-    def critical(self, *args, **kwargs):
-        pass
+    def __init__(self):
+        self.logged = []
 
-    def info(self, *args, **kwargs):
-        pass
+    def critical(self, msg, *args, **kwargs):
+        self.logged.append(msg)
 
-    def error(self, *args, **kwargs):
-        pass
+    def info(self, msg, *args, **kwargs):
+        self.logged.append(msg)
 
-    def debug(self, *args, **kwargs):
-        pass
+    def error(self, msg, *args, **kwargs):
+        self.logged.append(msg)
+
+    def debug(self, msg, *args, **kwargs):
+        self.logged.append(msg)
 
 
 class MockBackend(object):
@@ -123,7 +128,29 @@ def create_message(backend, **data):
                        content_encoding="binary")
 
 
-class TestCarrotListener(unittest.TestCase):
+class test_QoS(unittest.TestCase):
+
+    class MockConsumer(object):
+        prefetch_count = 0
+
+        def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
+            self.prefetch_count = prefetch_count
+
+    def test_decrement(self):
+        consumer = self.MockConsumer()
+        qos = QoS(consumer, 10, get_logger())
+        qos.update()
+        self.assertEqual(int(qos.value), 10)
+        self.assertEqual(consumer.prefetch_count, 10)
+        qos.decrement()
+        self.assertEqual(int(qos.value), 9)
+        self.assertEqual(consumer.prefetch_count, 9)
+        qos.decrement_eventually()
+        self.assertEqual(int(qos.value), 8)
+        self.assertEqual(consumer.prefetch_count, 9)
+
+
+class test_CarrotListener(unittest.TestCase):
 
     def setUp(self):
         self.ready_queue = FastQueue()
@@ -232,6 +259,38 @@ class TestCarrotListener(unittest.TestCase):
         context = catch_warnings(record=True)
         execute_context(context, with_catch_warnings)
 
+    def test_receive_message_InvalidTaskError(self):
+        logger = MockLogger()
+        l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
+                           send_events=False)
+        backend = MockBackend()
+        m = create_message(backend, task=foo_task.name,
+            args=(1, 2), kwargs="foobarbaz", id=1)
+        l.event_dispatcher = MockEventDispatcher()
+        l.control_dispatch = MockControlDispatch()
+
+        l.receive_message(m.decode(), m)
+        self.assertIn("Invalid task ignored", logger.logged[0])
+
+    def test_on_decode_error(self):
+        logger = MockLogger()
+        l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
+                           send_events=False)
+
+        class MockMessage(object):
+            content_type = "application/x-msgpack"
+            content_encoding = "binary"
+            body = "foobarbaz"
+            acked = False
+
+            def ack(self):
+                self.acked = True
+
+        message = MockMessage()
+        l.on_decode_error(message, KeyError("foo"))
+        self.assertTrue(message.acked)
+        self.assertIn("Message decoding error", logger.logged[0])
+
     def test_receieve_message(self):
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
@@ -331,8 +390,64 @@ class TestCarrotListener(unittest.TestCase):
         self.assertEqual(task.execute(), 2 * 4 * 8)
         self.assertRaises(Empty, self.ready_queue.get_nowait)
 
+    def test_start__consume_messages(self):
+
+        class _QoS(object):
+            prev = 3
+            next = 4
+
+            def update(self):
+                self.prev = self.next
+
+        class _Listener(CarrotListener):
+            iterations = 0
+            wait_method = None
+
+            def reset_connection(self):
+                if self.iterations >= 1:
+                    raise KeyError("foo")
+
+            def _detect_wait_method(self):
+                return self.wait_method
+
+        called_back = [False]
+        def init_callback(listener):
+            called_back[0] = True
+
+
+        l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
+                      send_events=False, init_callback=init_callback)
+        l.qos = _QoS()
+
+        def raises_KeyError(limit=None):
+            yield True
+            l.iterations = 1
+            raise KeyError("foo")
+
+        l.wait_method = raises_KeyError
+        self.assertRaises(KeyError, l.start)
+        self.assertTrue(called_back[0])
+        self.assertEqual(l.iterations, 1)
+        self.assertEqual(l.qos.prev, l.qos.next)
+
+        l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
+                      send_events=False, init_callback=init_callback)
+        l.qos = _QoS()
+        def raises_socket_error(limit=None):
+            yield True
+            l.iterations = 1
+            raise socket.error("foo")
+
+        l.wait_method = raises_socket_error
+        self.assertRaises(KeyError, l.start)
+        self.assertTrue(called_back[0])
+        self.assertEqual(l.iterations, 1)
+
+
+
+
 
-class TestWorkController(unittest.TestCase):
+class test_WorkController(unittest.TestCase):
 
     def setUp(self):
         self.worker = WorkController(concurrency=1, loglevel=0)