Browse Source

91% total coverage

Ask Solem 15 years ago
parent
commit
a085889e13

+ 32 - 1
celery/tests/test_task_control.py

@@ -3,6 +3,7 @@ import unittest2 as unittest
 from celery.task import control
 from celery.task.builtins import PingTask
 from celery.utils import gen_unique_id
+from celery.utils.functional import wraps
 
 
 class MockBroadcastPublisher(object):
@@ -18,26 +19,56 @@ class MockBroadcastPublisher(object):
         pass
 
 
+class MockControlReplyConsumer(object):
+
+    def __init__(self, *args, **kwarg):
+        pass
+
+    def collect(self, *args, **kwargs):
+        pass
+
+    def close(self):
+        pass
+
+
 def with_mock_broadcast(fun):
 
+    @wraps(fun)
     def _mocked(*args, **kwargs):
         old_pub = control.BroadcastPublisher
+        old_rep = control.ControlReplyConsumer
         control.BroadcastPublisher = MockBroadcastPublisher
+        control.ControlReplyConsumer = MockControlReplyConsumer
         try:
             return fun(*args, **kwargs)
         finally:
             MockBroadcastPublisher.sent = []
             control.BroadcastPublisher = old_pub
+            control.ControlReplyConsumer = old_rep
     return _mocked
 
 
-class TestBroadcast(unittest.TestCase):
+class test_Broadcast(unittest.TestCase):
+
+    def test_discard_all(self):
+        control.discard_all()
 
     @with_mock_broadcast
     def test_broadcast(self):
         control.broadcast("foobarbaz", arguments=[])
         self.assertIn("foobarbaz", MockBroadcastPublisher.sent)
 
+    @with_mock_broadcast
+    def test_broadcast_limit(self):
+        control.broadcast("foobarbaz1", arguments=[], limit=None,
+                destination=[1, 2, 3])
+        self.assertIn("foobarbaz1", MockBroadcastPublisher.sent)
+
+    @with_mock_broadcast
+    def test_broadcast_validate(self):
+        self.assertRaises(ValueError, control.broadcast, "foobarbaz2",
+                          destination="foo")
+
     @with_mock_broadcast
     def test_rate_limit(self):
         control.rate_limit(PingTask.name, "100/m")

+ 46 - 4
celery/tests/test_utils.py

@@ -7,7 +7,7 @@ from celery import utils
 from celery.tests.utils import sleepdeprived, execute_context
 from celery.tests.utils import mask_modules
 
-class TestChunks(unittest.TestCase):
+class test_chunks(unittest.TestCase):
 
     def test_chunks(self):
 
@@ -27,7 +27,7 @@ class TestChunks(unittest.TestCase):
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
 
 
-class TestGenUniqueId(unittest.TestCase):
+class test_gen_unique_id(unittest.TestCase):
 
     def test_gen_unique_id_without_ctypes(self):
         old_utils = sys.modules.pop("celery.utils")
@@ -47,7 +47,7 @@ class TestGenUniqueId(unittest.TestCase):
             sys.modules["celery.utils"] = old_utils
 
 
-class TestDivUtils(unittest.TestCase):
+class test_utils(unittest.TestCase):
 
     def test_repeatlast(self):
         items = range(6)
@@ -57,8 +57,50 @@ class TestDivUtils(unittest.TestCase):
         for j in items:
             self.assertEqual(it.next(), i)
 
+    def test_get_full_cls_name(self):
+        Class = type("Fox", (object, ), {"__module__": "quick.brown"})
+        self.assertEqual(utils.get_full_cls_name(Class), "quick.brown.Fox")
+
+    def test_is_iterable(self):
+        for a in "f", ["f"], ("f", ), {"f": "f"}:
+            self.assertTrue(utils.is_iterable(a))
+        for b in object(), 1:
+            self.assertFalse(utils.is_iterable(b))
+
+    def test_padlist(self):
+        self.assertListEqual(utils.padlist(["George", "Costanza", "NYC"], 3),
+                ["George", "Costanza", "NYC"])
+        self.assertListEqual(utils.padlist(["George", "Costanza"], 3),
+                ["George", "Costanza", None])
+        self.assertListEqual(utils.padlist(["George", "Costanza", "NYC"], 4,
+                                           default="Earth"),
+                ["George", "Costanza", "NYC", "Earth"])
+
+    def test_firstmethod_AttributeError(self):
+        self.assertIsNone(utils.firstmethod("foo")([object()]))
+
+    def test_first(self):
+        iterations = [0]
+
+        def predicate(value):
+            iterations[0] += 1
+            if value == 5:
+                return True
+            return False
+
+        self.assertEqual(5, utils.first(predicate, xrange(10)))
+        self.assertEqual(iterations[0], 6)
+
+        iterations[0] = 0
+        self.assertIsNone(utils.first(predicate, xrange(10, 20)))
+        self.assertEqual(iterations[0], 10)
+
+    def test_get_cls_by_name__instance_returns_instance(self):
+        instance = object()
+        self.assertIs(utils.get_cls_by_name(instance), instance)
+
 
-class TestRetryOverTime(unittest.TestCase):
+class test_retry_over_time(unittest.TestCase):
 
     def test_returns_retval_on_success(self):
 

+ 42 - 2
celery/tests/test_worker.py

@@ -88,6 +88,8 @@ class MockBackend(object):
 
 
 class MockPool(object):
+    _terminated = False
+    _stopped = False
 
     def __init__(self, *args, **kwargs):
         self.raise_regular = kwargs.get("raise_regular", False)
@@ -103,9 +105,13 @@ class MockPool(object):
         pass
 
     def stop(self):
-        pass
+        self._stopped = True
         return True
 
+    def terminate(self):
+        self._terminated = True
+        self.stop()
+
 
 class MockController(object):
 
@@ -453,6 +459,14 @@ class test_WorkController(unittest.TestCase):
         self.worker = WorkController(concurrency=1, loglevel=0)
         self.worker.logger = MockLogger()
 
+    def test_with_rate_limits_disabled(self):
+        conf.DISABLE_RATE_LIMITS = True
+        try:
+            worker = WorkController(concurrency=1, loglevel=0)
+            self.assertIsInstance(worker.ready_queue, FastQueue)
+        finally:
+            conf.DISABLE_RATE_LIMITS = False
+
     def test_attrs(self):
         worker = self.worker
         self.assertIsInstance(worker.eta_schedule, Scheduler)
@@ -462,6 +476,12 @@ class test_WorkController(unittest.TestCase):
         self.assertTrue(worker.mediator)
         self.assertTrue(worker.components)
 
+    def test_with_embedded_clockservice(self):
+        worker = WorkController(concurrency=1, loglevel=0,
+                                embed_clockservice=True)
+        self.assertTrue(worker.clockservice)
+        self.assertIn(worker.clockservice, worker.components)
+
     def test_process_task(self):
         worker = self.worker
         worker.pool = MockPool()
@@ -492,7 +512,7 @@ class test_WorkController(unittest.TestCase):
         worker.process_task(task)
         worker.pool.stop()
 
-    def test_start_stop(self):
+    def test_start__stop(self):
         worker = self.worker
         w1 = {"started": False}
         w2 = {"started": False}
@@ -508,3 +528,23 @@ class test_WorkController(unittest.TestCase):
         worker.stop()
         for component in worker.components:
             self.assertTrue(component._stopped)
+
+    def test_start__terminate(self):
+        worker = self.worker
+        w1 = {"started": False}
+        w2 = {"started": False}
+        w3 = {"started": False}
+        w4 = {"started": False}
+        worker.components = [MockController(w1), MockController(w2),
+                             MockController(w3), MockController(w4),
+                             MockPool()]
+
+        worker.start()
+        for w in (w1, w2, w3, w4):
+            self.assertTrue(w["started"])
+        self.assertTrue(worker._running, len(worker.components))
+        self.assertEqual(worker._state, RUN)
+        worker.terminate()
+        for component in worker.components:
+            self.assertTrue(component._stopped)
+        self.assertTrue(worker.components[4]._terminated)

+ 5 - 8
celery/worker/__init__.py

@@ -191,14 +191,11 @@ class WorkController(object):
         """Starts the workers main loop."""
         self._state = RUN
 
-        try:
-            for i, component in enumerate(self.components):
-                self.logger.debug("Starting thread %s..." % \
-                        component.__class__.__name__)
-                self._running = i + 1
-                component.start()
-        finally:
-            self.stop()
+        for i, component in enumerate(self.components):
+            self.logger.debug("Starting thread %s..." % (
+                                    component.__class__.__name__))
+            self._running = i + 1
+            component.start()
 
     def process_task(self, wrapper):
         """Process task by sending it to the pool of workers."""

+ 2 - 2
celery/worker/listener.py

@@ -94,8 +94,8 @@ from celery.messaging import get_consumer_set, BroadcastConsumer
 from celery.exceptions import NotRegistered
 from celery.datastructures import SharedCounter
 
-RUN = 0x0
-CLOSE = 0x1
+RUN = 0x1
+CLOSE = 0x2
 
 
 class QoS(object):