|
@@ -1,24 +1,26 @@
|
|
|
+import socket
|
|
|
import unittest2 as unittest
|
|
|
-from Queue import Empty
|
|
|
+
|
|
|
from datetime import datetime, timedelta
|
|
|
from multiprocessing import get_logger
|
|
|
+from Queue import Empty
|
|
|
|
|
|
-from carrot.connection import BrokerConnection
|
|
|
from carrot.backends.base import BaseMessage
|
|
|
+from carrot.connection import BrokerConnection
|
|
|
|
|
|
from celery import conf
|
|
|
+from celery.decorators import task as task_dec
|
|
|
+from celery.decorators import periodic_task as periodic_task_dec
|
|
|
+from celery.serialization import pickle
|
|
|
from celery.utils import gen_unique_id
|
|
|
from celery.worker import WorkController
|
|
|
-from celery.worker.job import TaskRequest
|
|
|
from celery.worker.buckets import FastQueue
|
|
|
+from celery.worker.job import TaskRequest
|
|
|
from celery.worker.listener import CarrotListener, QoS, RUN
|
|
|
from celery.worker.scheduler import Scheduler
|
|
|
-from celery.decorators import task as task_dec
|
|
|
-from celery.decorators import periodic_task as periodic_task_dec
|
|
|
-from celery.serialization import pickle
|
|
|
|
|
|
-from celery.tests.utils import execute_context
|
|
|
from celery.tests.compat import catch_warnings
|
|
|
+from celery.tests.utils import execute_context
|
|
|
|
|
|
|
|
|
class PlaceHolder(object):
|
|
@@ -62,17 +64,20 @@ def foo_periodic_task():
|
|
|
|
|
|
class MockLogger(object):
|
|
|
|
|
|
- def critical(self, *args, **kwargs):
|
|
|
- pass
|
|
|
+ def __init__(self):
|
|
|
+ self.logged = []
|
|
|
|
|
|
- def info(self, *args, **kwargs):
|
|
|
- pass
|
|
|
+ def critical(self, msg, *args, **kwargs):
|
|
|
+ self.logged.append(msg)
|
|
|
|
|
|
- def error(self, *args, **kwargs):
|
|
|
- pass
|
|
|
+ def info(self, msg, *args, **kwargs):
|
|
|
+ self.logged.append(msg)
|
|
|
|
|
|
- def debug(self, *args, **kwargs):
|
|
|
- pass
|
|
|
+ def error(self, msg, *args, **kwargs):
|
|
|
+ self.logged.append(msg)
|
|
|
+
|
|
|
+ def debug(self, msg, *args, **kwargs):
|
|
|
+ self.logged.append(msg)
|
|
|
|
|
|
|
|
|
class MockBackend(object):
|
|
@@ -123,7 +128,29 @@ def create_message(backend, **data):
|
|
|
content_encoding="binary")
|
|
|
|
|
|
|
|
|
-class TestCarrotListener(unittest.TestCase):
|
|
|
+class test_QoS(unittest.TestCase):
|
|
|
+
|
|
|
+ class MockConsumer(object):
|
|
|
+ prefetch_count = 0
|
|
|
+
|
|
|
+ def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
|
|
|
+ self.prefetch_count = prefetch_count
|
|
|
+
|
|
|
+ def test_decrement(self):
|
|
|
+ consumer = self.MockConsumer()
|
|
|
+ qos = QoS(consumer, 10, get_logger())
|
|
|
+ qos.update()
|
|
|
+ self.assertEqual(int(qos.value), 10)
|
|
|
+ self.assertEqual(consumer.prefetch_count, 10)
|
|
|
+ qos.decrement()
|
|
|
+ self.assertEqual(int(qos.value), 9)
|
|
|
+ self.assertEqual(consumer.prefetch_count, 9)
|
|
|
+ qos.decrement_eventually()
|
|
|
+ self.assertEqual(int(qos.value), 8)
|
|
|
+ self.assertEqual(consumer.prefetch_count, 9)
|
|
|
+
|
|
|
+
|
|
|
+class test_CarrotListener(unittest.TestCase):
|
|
|
|
|
|
def setUp(self):
|
|
|
self.ready_queue = FastQueue()
|
|
@@ -232,6 +259,38 @@ class TestCarrotListener(unittest.TestCase):
|
|
|
context = catch_warnings(record=True)
|
|
|
execute_context(context, with_catch_warnings)
|
|
|
|
|
|
+ def test_receive_message_InvalidTaskError(self):
|
|
|
+ logger = MockLogger()
|
|
|
+ l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
|
|
|
+ send_events=False)
|
|
|
+ backend = MockBackend()
|
|
|
+ m = create_message(backend, task=foo_task.name,
|
|
|
+ args=(1, 2), kwargs="foobarbaz", id=1)
|
|
|
+ l.event_dispatcher = MockEventDispatcher()
|
|
|
+ l.control_dispatch = MockControlDispatch()
|
|
|
+
|
|
|
+ l.receive_message(m.decode(), m)
|
|
|
+ self.assertIn("Invalid task ignored", logger.logged[0])
|
|
|
+
|
|
|
+ def test_on_decode_error(self):
|
|
|
+ logger = MockLogger()
|
|
|
+ l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
|
|
|
+ send_events=False)
|
|
|
+
|
|
|
+ class MockMessage(object):
|
|
|
+ content_type = "application/x-msgpack"
|
|
|
+ content_encoding = "binary"
|
|
|
+ body = "foobarbaz"
|
|
|
+ acked = False
|
|
|
+
|
|
|
+ def ack(self):
|
|
|
+ self.acked = True
|
|
|
+
|
|
|
+ message = MockMessage()
|
|
|
+ l.on_decode_error(message, KeyError("foo"))
|
|
|
+ self.assertTrue(message.acked)
|
|
|
+ self.assertIn("Message decoding error", logger.logged[0])
|
|
|
+
|
|
|
def test_receieve_message(self):
|
|
|
l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
send_events=False)
|
|
@@ -331,8 +390,64 @@ class TestCarrotListener(unittest.TestCase):
|
|
|
self.assertEqual(task.execute(), 2 * 4 * 8)
|
|
|
self.assertRaises(Empty, self.ready_queue.get_nowait)
|
|
|
|
|
|
+ def test_start__consume_messages(self):
|
|
|
+
|
|
|
+ class _QoS(object):
|
|
|
+ prev = 3
|
|
|
+ next = 4
|
|
|
+
|
|
|
+ def update(self):
|
|
|
+ self.prev = self.next
|
|
|
+
|
|
|
+ class _Listener(CarrotListener):
|
|
|
+ iterations = 0
|
|
|
+ wait_method = None
|
|
|
+
|
|
|
+ def reset_connection(self):
|
|
|
+ if self.iterations >= 1:
|
|
|
+ raise KeyError("foo")
|
|
|
+
|
|
|
+ def _detect_wait_method(self):
|
|
|
+ return self.wait_method
|
|
|
+
|
|
|
+ called_back = [False]
|
|
|
+ def init_callback(listener):
|
|
|
+ called_back[0] = True
|
|
|
+
|
|
|
+
|
|
|
+ l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
+ send_events=False, init_callback=init_callback)
|
|
|
+ l.qos = _QoS()
|
|
|
+
|
|
|
+ def raises_KeyError(limit=None):
|
|
|
+ yield True
|
|
|
+ l.iterations = 1
|
|
|
+ raise KeyError("foo")
|
|
|
+
|
|
|
+ l.wait_method = raises_KeyError
|
|
|
+ self.assertRaises(KeyError, l.start)
|
|
|
+ self.assertTrue(called_back[0])
|
|
|
+ self.assertEqual(l.iterations, 1)
|
|
|
+ self.assertEqual(l.qos.prev, l.qos.next)
|
|
|
+
|
|
|
+ l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
|
|
|
+ send_events=False, init_callback=init_callback)
|
|
|
+ l.qos = _QoS()
|
|
|
+ def raises_socket_error(limit=None):
|
|
|
+ yield True
|
|
|
+ l.iterations = 1
|
|
|
+ raise socket.error("foo")
|
|
|
+
|
|
|
+ l.wait_method = raises_socket_error
|
|
|
+ self.assertRaises(KeyError, l.start)
|
|
|
+ self.assertTrue(called_back[0])
|
|
|
+ self.assertEqual(l.iterations, 1)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
-class TestWorkController(unittest.TestCase):
|
|
|
+class test_WorkController(unittest.TestCase):
|
|
|
|
|
|
def setUp(self):
|
|
|
self.worker = WorkController(concurrency=1, loglevel=0)
|