|  | @@ -1,24 +1,26 @@
 | 
	
		
			
				|  |  | +import socket
 | 
	
		
			
				|  |  |  import unittest2 as unittest
 | 
	
		
			
				|  |  | -from Queue import Empty
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  from datetime import datetime, timedelta
 | 
	
		
			
				|  |  |  from multiprocessing import get_logger
 | 
	
		
			
				|  |  | +from Queue import Empty
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from carrot.connection import BrokerConnection
 | 
	
		
			
				|  |  |  from carrot.backends.base import BaseMessage
 | 
	
		
			
				|  |  | +from carrot.connection import BrokerConnection
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from celery import conf
 | 
	
		
			
				|  |  | +from celery.decorators import task as task_dec
 | 
	
		
			
				|  |  | +from celery.decorators import periodic_task as periodic_task_dec
 | 
	
		
			
				|  |  | +from celery.serialization import pickle
 | 
	
		
			
				|  |  |  from celery.utils import gen_unique_id
 | 
	
		
			
				|  |  |  from celery.worker import WorkController
 | 
	
		
			
				|  |  | -from celery.worker.job import TaskRequest
 | 
	
		
			
				|  |  |  from celery.worker.buckets import FastQueue
 | 
	
		
			
				|  |  | +from celery.worker.job import TaskRequest
 | 
	
		
			
				|  |  |  from celery.worker.listener import CarrotListener, QoS, RUN
 | 
	
		
			
				|  |  |  from celery.worker.scheduler import Scheduler
 | 
	
		
			
				|  |  | -from celery.decorators import task as task_dec
 | 
	
		
			
				|  |  | -from celery.decorators import periodic_task as periodic_task_dec
 | 
	
		
			
				|  |  | -from celery.serialization import pickle
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from celery.tests.utils import execute_context
 | 
	
		
			
				|  |  |  from celery.tests.compat import catch_warnings
 | 
	
		
			
				|  |  | +from celery.tests.utils import execute_context
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class PlaceHolder(object):
 | 
	
	
		
			
				|  | @@ -62,17 +64,20 @@ def foo_periodic_task():
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class MockLogger(object):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def critical(self, *args, **kwargs):
 | 
	
		
			
				|  |  | -        pass
 | 
	
		
			
				|  |  | +    def __init__(self):
 | 
	
		
			
				|  |  | +        self.logged = []
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def info(self, *args, **kwargs):
 | 
	
		
			
				|  |  | -        pass
 | 
	
		
			
				|  |  | +    def critical(self, msg, *args, **kwargs):
 | 
	
		
			
				|  |  | +        self.logged.append(msg)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def error(self, *args, **kwargs):
 | 
	
		
			
				|  |  | -        pass
 | 
	
		
			
				|  |  | +    def info(self, msg, *args, **kwargs):
 | 
	
		
			
				|  |  | +        self.logged.append(msg)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def debug(self, *args, **kwargs):
 | 
	
		
			
				|  |  | -        pass
 | 
	
		
			
				|  |  | +    def error(self, msg, *args, **kwargs):
 | 
	
		
			
				|  |  | +        self.logged.append(msg)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def debug(self, msg, *args, **kwargs):
 | 
	
		
			
				|  |  | +        self.logged.append(msg)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class MockBackend(object):
 | 
	
	
		
			
				|  | @@ -123,7 +128,29 @@ def create_message(backend, **data):
 | 
	
		
			
				|  |  |                         content_encoding="binary")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class TestCarrotListener(unittest.TestCase):
 | 
	
		
			
				|  |  | +class test_QoS(unittest.TestCase):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    class MockConsumer(object):
 | 
	
		
			
				|  |  | +        prefetch_count = 0
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
 | 
	
		
			
				|  |  | +            self.prefetch_count = prefetch_count
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_decrement(self):
 | 
	
		
			
				|  |  | +        consumer = self.MockConsumer()
 | 
	
		
			
				|  |  | +        qos = QoS(consumer, 10, get_logger())
 | 
	
		
			
				|  |  | +        qos.update()
 | 
	
		
			
				|  |  | +        self.assertEqual(int(qos.value), 10)
 | 
	
		
			
				|  |  | +        self.assertEqual(consumer.prefetch_count, 10)
 | 
	
		
			
				|  |  | +        qos.decrement()
 | 
	
		
			
				|  |  | +        self.assertEqual(int(qos.value), 9)
 | 
	
		
			
				|  |  | +        self.assertEqual(consumer.prefetch_count, 9)
 | 
	
		
			
				|  |  | +        qos.decrement_eventually()
 | 
	
		
			
				|  |  | +        self.assertEqual(int(qos.value), 8)
 | 
	
		
			
				|  |  | +        self.assertEqual(consumer.prefetch_count, 9)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class test_CarrotListener(unittest.TestCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def setUp(self):
 | 
	
		
			
				|  |  |          self.ready_queue = FastQueue()
 | 
	
	
		
			
				|  | @@ -232,6 +259,38 @@ class TestCarrotListener(unittest.TestCase):
 | 
	
		
			
				|  |  |          context = catch_warnings(record=True)
 | 
	
		
			
				|  |  |          execute_context(context, with_catch_warnings)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_receive_message_InvalidTaskError(self):
 | 
	
		
			
				|  |  | +        logger = MockLogger()
 | 
	
		
			
				|  |  | +        l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
 | 
	
		
			
				|  |  | +                           send_events=False)
 | 
	
		
			
				|  |  | +        backend = MockBackend()
 | 
	
		
			
				|  |  | +        m = create_message(backend, task=foo_task.name,
 | 
	
		
			
				|  |  | +            args=(1, 2), kwargs="foobarbaz", id=1)
 | 
	
		
			
				|  |  | +        l.event_dispatcher = MockEventDispatcher()
 | 
	
		
			
				|  |  | +        l.control_dispatch = MockControlDispatch()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        l.receive_message(m.decode(), m)
 | 
	
		
			
				|  |  | +        self.assertIn("Invalid task ignored", logger.logged[0])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_on_decode_error(self):
 | 
	
		
			
				|  |  | +        logger = MockLogger()
 | 
	
		
			
				|  |  | +        l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
 | 
	
		
			
				|  |  | +                           send_events=False)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class MockMessage(object):
 | 
	
		
			
				|  |  | +            content_type = "application/x-msgpack"
 | 
	
		
			
				|  |  | +            content_encoding = "binary"
 | 
	
		
			
				|  |  | +            body = "foobarbaz"
 | 
	
		
			
				|  |  | +            acked = False
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def ack(self):
 | 
	
		
			
				|  |  | +                self.acked = True
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        message = MockMessage()
 | 
	
		
			
				|  |  | +        l.on_decode_error(message, KeyError("foo"))
 | 
	
		
			
				|  |  | +        self.assertTrue(message.acked)
 | 
	
		
			
				|  |  | +        self.assertIn("Message decoding error", logger.logged[0])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_receieve_message(self):
 | 
	
		
			
				|  |  |          l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
 | 
	
		
			
				|  |  |                             send_events=False)
 | 
	
	
		
			
				|  | @@ -331,8 +390,64 @@ class TestCarrotListener(unittest.TestCase):
 | 
	
		
			
				|  |  |          self.assertEqual(task.execute(), 2 * 4 * 8)
 | 
	
		
			
				|  |  |          self.assertRaises(Empty, self.ready_queue.get_nowait)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_start__consume_messages(self):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class _QoS(object):
 | 
	
		
			
				|  |  | +            prev = 3
 | 
	
		
			
				|  |  | +            next = 4
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def update(self):
 | 
	
		
			
				|  |  | +                self.prev = self.next
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class _Listener(CarrotListener):
 | 
	
		
			
				|  |  | +            iterations = 0
 | 
	
		
			
				|  |  | +            wait_method = None
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def reset_connection(self):
 | 
	
		
			
				|  |  | +                if self.iterations >= 1:
 | 
	
		
			
				|  |  | +                    raise KeyError("foo")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def _detect_wait_method(self):
 | 
	
		
			
				|  |  | +                return self.wait_method
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        called_back = [False]
 | 
	
		
			
				|  |  | +        def init_callback(listener):
 | 
	
		
			
				|  |  | +            called_back[0] = True
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
 | 
	
		
			
				|  |  | +                      send_events=False, init_callback=init_callback)
 | 
	
		
			
				|  |  | +        l.qos = _QoS()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def raises_KeyError(limit=None):
 | 
	
		
			
				|  |  | +            yield True
 | 
	
		
			
				|  |  | +            l.iterations = 1
 | 
	
		
			
				|  |  | +            raise KeyError("foo")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        l.wait_method = raises_KeyError
 | 
	
		
			
				|  |  | +        self.assertRaises(KeyError, l.start)
 | 
	
		
			
				|  |  | +        self.assertTrue(called_back[0])
 | 
	
		
			
				|  |  | +        self.assertEqual(l.iterations, 1)
 | 
	
		
			
				|  |  | +        self.assertEqual(l.qos.prev, l.qos.next)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
 | 
	
		
			
				|  |  | +                      send_events=False, init_callback=init_callback)
 | 
	
		
			
				|  |  | +        l.qos = _QoS()
 | 
	
		
			
				|  |  | +        def raises_socket_error(limit=None):
 | 
	
		
			
				|  |  | +            yield True
 | 
	
		
			
				|  |  | +            l.iterations = 1
 | 
	
		
			
				|  |  | +            raise socket.error("foo")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        l.wait_method = raises_socket_error
 | 
	
		
			
				|  |  | +        self.assertRaises(KeyError, l.start)
 | 
	
		
			
				|  |  | +        self.assertTrue(called_back[0])
 | 
	
		
			
				|  |  | +        self.assertEqual(l.iterations, 1)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class TestWorkController(unittest.TestCase):
 | 
	
		
			
				|  |  | +class test_WorkController(unittest.TestCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def setUp(self):
 | 
	
		
			
				|  |  |          self.worker = WorkController(concurrency=1, loglevel=0)
 |