Browse Source

100% Coverage for celery.worker.autoscale

Ask Solem 14 years ago
parent
commit
1c99a986dd

+ 114 - 1
celery/tests/test_worker.py

@@ -173,7 +173,7 @@ class test_QoS(unittest.TestCase):
         def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
             self.prefetch_count = prefetch_count
 
-    def test_decrement(self):
+    def test_increment_decrement(self):
         consumer = self.MockConsumer()
         qos = QoS(consumer, 10, app_or_default().log.get_default_logger())
         qos.update()
@@ -186,6 +186,13 @@ class test_QoS(unittest.TestCase):
         self.assertEqual(int(qos.value), 8)
         self.assertEqual(consumer.prefetch_count, 9)
 
+        # Does not decrement 0 value
+        qos.value._value = 0
+        qos.decrement()
+        self.assertEqual(int(qos.value), 0)
+        qos.increment()
+        self.assertEqual(int(qos.value), 0)
+
 
 class test_Consumer(unittest.TestCase):
 
@@ -198,6 +205,18 @@ class test_Consumer(unittest.TestCase):
     def tearDown(self):
         self.eta_schedule.stop()
 
+    def test_info(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        info = l.info
+        self.assertEqual(info["prefetch_count"], 10)
+        self.assertFalse(info["broker"])
+
+        l.connection = app_or_default().broker_connection()
+        info = l.info
+        self.assertTrue(info["broker"])
+
     def test_connection(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
@@ -205,6 +224,12 @@ class test_Consumer(unittest.TestCase):
         l.reset_connection()
         self.assertIsInstance(l.connection, BrokerConnection)
 
+        l._state = RUN
+        l.event_dispatcher = None
+        l.stop_consumers(close=False)
+        self.assertTrue(l.connection)
+
+        l._state = RUN
         l.stop_consumers()
         self.assertIsNone(l.connection)
         self.assertIsNone(l.task_consumer)
@@ -323,6 +348,94 @@ class test_Consumer(unittest.TestCase):
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.eta_schedule.empty())
 
+    def test_start_connection_error(self):
+
+        class MockConsumer(MainConsumer):
+            iterations = 0
+
+            def consume_messages(self):
+                if not self.iterations:
+                    self.iterations = 1
+                    raise KeyError("foo")
+                raise SyntaxError("bar")
+
+        l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.connection_errors = (KeyError, )
+        self.assertRaises(SyntaxError, l.start)
+        l.heart.stop()
+
+    def test_consume_messages(self):
+        app = app_or_default()
+
+        class Connection(app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+
+        class Consumer(object):
+            consuming = False
+            prefetch_count = 0
+
+            def consume(self):
+                self.consuming = True
+
+            def qos(self, prefetch_size=0, prefetch_count=0,
+                            apply_global=False):
+                self.prefetch_count = prefetch_count
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.connection = Connection()
+        l.connection.obj = l
+        l.task_consumer = Consumer()
+        l.broadcast_consumer = Consumer()
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+
+        l.consume_messages()
+        l.consume_messages()
+        self.assertTrue(l.task_consumer.consuming)
+        self.assertTrue(l.broadcast_consumer.consuming)
+        self.assertEqual(l.task_consumer.prefetch_count, 10)
+
+        l.qos.decrement()
+        l.consume_messages()
+        self.assertEqual(l.task_consumer.prefetch_count, 9)
+
+    def test_maybe_conn_error(self):
+
+        def raises(error):
+
+            def fun():
+                raise error
+
+            return fun
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.connection_errors = (KeyError, )
+        l.channel_errors = (SyntaxError, )
+        l.maybe_conn_error(raises(AttributeError("foo")))
+        l.maybe_conn_error(raises(KeyError("foo")))
+        l.maybe_conn_error(raises(SyntaxError("foo")))
+        self.assertRaises(IndexError, l.maybe_conn_error,
+                raises(IndexError("foo")))
+
+
+    def test_apply_eta_task(self):
+        from celery.worker import state
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.qos = QoS(None, 10, l.logger)
+
+        task = object()
+        qos = l.qos.next
+        l.apply_eta_task(task)
+        self.assertIn(task, state.reserved_requests)
+        self.assertEqual(l.qos.next, qos - 1)
+        self.assertIs(self.ready_queue.get_nowait(), task)
+
     def test_receieve_message_eta_isoformat(self):
 
         class MockConsumer(object):

+ 104 - 0
celery/tests/test_worker_autoscale.py

@@ -0,0 +1,104 @@
+import logging
+
+from time import time
+
+from celery.concurrency.base import BasePool
+from celery.worker import state
+from celery.worker import autoscale
+
+from celery.tests.utils import unittest, sleepdeprived
+
+logger = logging.getLogger("celery.tests.autoscale")
+
+
+class Object(object):
+    pass
+
+
+class MockPool(BasePool):
+    shrink_raises_exception = False
+
+    def __init__(self, *args, **kwargs):
+        super(MockPool, self).__init__(*args, **kwargs)
+        self._pool = Object()
+        self._pool._processes = self.limit
+
+    def grow(self, n=1):
+        self._pool._processes += n
+
+    def shrink(self, n=1):
+        if self.shrink_raises_exception:
+            raise KeyError("foo")
+        self._pool._processes -= n
+
+    @property
+    def current(self):
+        return self._pool._processes
+
+
+class test_Autoscaler(unittest.TestCase):
+
+    def setUp(self):
+        self.pool = MockPool(3)
+
+    def test_stop(self):
+
+        class Scaler(autoscale.Autoscaler):
+            alive = True
+            joined = False
+
+            def isAlive(self):
+                return self.alive
+
+            def join(self, timeout=None):
+                self.joined = True
+
+        x = Scaler(self.pool, 10, 3, logger=logger)
+        x._stopped.set()
+        x.stop()
+        self.assertTrue(x.joined)
+        x.joined = False
+        x.alive = False
+        x.stop()
+        self.assertFalse(x.joined)
+
+    @sleepdeprived(autoscale)
+    def test_scale(self):
+        x = autoscale.Autoscaler(self.pool, 10, 3, logger=logger)
+        x.scale()
+        self.assertEqual(x.pool.current, 3)
+        for i in range(20):
+            state.reserved_requests.add(i)
+        x.scale()
+        x.scale()
+        self.assertEqual(x.pool.current, 10)
+        state.reserved_requests.clear()
+        x.scale()
+        self.assertEqual(x.pool.current, 10)
+        x._last_action = time() - 10000
+        x.scale()
+        self.assertEqual(x.pool.current, 3)
+
+    def test_run(self):
+
+        class Scaler(autoscale.Autoscaler):
+            scale_called = False
+
+            def scale(self):
+                self.scale_called = True
+                self._shutdown.set()
+
+        x = Scaler(self.pool, 10, 3, logger=logger)
+        x.run()
+        self.assertTrue(x._shutdown.isSet())
+        self.assertTrue(x._stopped.isSet())
+        self.assertTrue(x.scale_called)
+
+
+
+    def test_shrink_raises_exception(self):
+        x = autoscale.Autoscaler(self.pool, 10, 3, logger=logger)
+        x.scale_up(3)
+        x._last_action = time() - 10000
+        x.pool.shrink_raises_exception = True
+        x.scale_down(1)

+ 1 - 5
celery/tests/test_worker_heartbeat.py

@@ -58,7 +58,6 @@ class TestHeart(unittest.TestCase):
         heart.run()
         self.assertEqual(heart._state, "RUN")
         self.assertTrue(heart._shutdown.isSet())
-        self.assertTrue(heart._stopped.isSet())
 
     def test_run(self):
         eventer = MockDispatcher()
@@ -71,8 +70,6 @@ class TestHeart(unittest.TestCase):
         self.assertIn("worker-heartbeat", eventer.sent)
         self.assertIn("worker-offline", eventer.sent)
 
-        self.assertTrue(heart._stopped.isSet())
-
         heart.stop()
         heart.stop()
         self.assertEqual(heart._state, "CLOSE")
@@ -82,9 +79,8 @@ class TestHeart(unittest.TestCase):
         for i in range(10):
             heart.run()
 
-    def test_run_stopped_is_set_even_if_send_breaks(self):
+    def test_run_exception(self):
         eventer = MockDispatcherRaising()
         heart = Heart(eventer, interval=1)
         heart._shutdown.set()
         self.assertRaises(Exception, heart.run)
-        self.assertTrue(heart._stopped.isSet())

+ 13 - 10
celery/tests/utils.py

@@ -8,6 +8,7 @@ except AttributeError:
 
 import os
 import sys
+import time
 try:
     import __builtin__ as builtins
 except ImportError:    # py3k
@@ -122,18 +123,20 @@ def with_environ(env_name, env_value):
         return _patch_environ
     return _envpatched
 
+def sleepdeprived(module=time):
 
-def sleepdeprived(fun):
+    def _sleepdeprived(fun):
 
-    @wraps(fun)
-    def _sleepdeprived(*args, **kwargs):
-        import time
-        old_sleep = time.sleep
-        time.sleep = noop
-        try:
-            return fun(*args, **kwargs)
-        finally:
-            time.sleep = old_sleep
+        @wraps(fun)
+        def __sleepdeprived(*args, **kwargs):
+            old_sleep = module.sleep
+            module.sleep = noop
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                module.sleep = old_sleep
+
+        return __sleepdeprived
 
     return _sleepdeprived
 

+ 4 - 3
celery/worker/autoscale.py

@@ -47,8 +47,8 @@ class Autoscaler(threading.Thread):
             try:
                 self.pool.shrink(n)
             except Exception, exc:
-                traceback.print_stack()
-                self.logger.error("Autoscaler: scale_down: %r" % (exc, ),
+                self.logger.error("Autoscaler: scale_down: %r\n%r" % (
+                                    exc, traceback.format_stack()),
                                   exc_info=sys.exc_info())
 
     def run(self):
@@ -59,7 +59,8 @@ class Autoscaler(threading.Thread):
     def stop(self):
         self._shutdown.set()
         self._stopped.wait()
-        self.join(1e100)
+        if self.isAlive():
+            self.join(1e100)
 
     @property
     def qty(self):

+ 0 - 1
celery/worker/consumer.py

@@ -197,7 +197,6 @@ class Consumer(object):
     def __init__(self, ready_queue, eta_schedule, logger,
             init_callback=noop, send_events=False, hostname=None,
             initial_prefetch_count=2, pool=None, queues=None, app=None):
-
         self.app = app_or_default(app)
         self.connection = None
         self.task_consumer = None

+ 2 - 6
celery/worker/heartbeat.py

@@ -20,7 +20,6 @@ class Heart(threading.Thread):
         self.eventer = eventer
         self.bpm = interval and interval / 60.0 or self.bpm
         self._shutdown = threading.Event()
-        self._stopped = threading.Event()
         self.setDaemon(True)
         self.setName(self.__class__.__name__)
         self._state = None
@@ -40,7 +39,7 @@ class Heart(threading.Thread):
         while 1:
             try:
                 now = time()
-            except TypeError:
+            except TypeError:  # pragma: no cover
                 # we lost the race at interpreter shutdown,
                 # so time has been collected by gc.
                 return
@@ -52,10 +51,7 @@ class Heart(threading.Thread):
                 break
             sleep(1)
 
-        try:
-            dispatch("worker-offline")
-        finally:
-            self._stopped.set()
+        dispatch("worker-offline")
 
     def stop(self):
         """Gracefully shutdown the thread."""