from __future__ import absolute_import, unicode_literals

import sys

from case import Mock, mock, patch

from celery.concurrency.base import BasePool
from celery.five import monotonic
from celery.utils.objects import Bunch
from celery.worker import autoscale, state
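
# MockPool fakes just enough of a worker pool for the autoscaler tests:
# grow()/shrink() adjust a simulated process count, and the two class-level
# flags let individual tests force shrink() to raise.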
class MockPool(BasePool):
    shrink_raises_exception = False
    shrink_raises_ValueError = False

    def __init__(self, *args, **kwargs):
        super(MockPool, self).__init__(*args, **kwargs)
        self._pool = Bunch(_processes=self.limit)

    def grow(self, n=1):
        self._pool._processes += n

    def shrink(self, n=1):
        if self.shrink_raises_exception:
            raise KeyError('foo')
        if self.shrink_raises_ValueError:
            raise ValueError('foo')
        self._pool._processes -= n

    @property
    def num_processes(self):
        return self._pool._processes
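
# Not part of the original suite: a minimal sanity check of the fake pool
# itself, using only the MockPool API defined above.
def test_mockpool_grow_and_shrink():
    pool = MockPool(3)
    assert pool.num_processes == 3
    pool.grow(2)
    assert pool.num_processes == 5
    pool.shrink(1)
    assert pool.num_processes == 4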
class test_WorkerComponent:

    def test_register_with_event_loop(self):
        parent = Mock(name='parent')
        parent.autoscale = True
        parent.consumer.on_task_message = set()
        w = autoscale.WorkerComponent(parent)
        assert parent.autoscaler is None
        assert w.enabled

        hub = Mock(name='hub')
        w.create(parent)
        w.register_with_event_loop(parent, hub)
        assert (parent.autoscaler.maybe_scale in
                parent.consumer.on_task_message)
        hub.call_repeatedly.assert_called_with(
            parent.autoscaler.keepalive, parent.autoscaler.maybe_scale,
        )

        parent.hub = hub
        hub.on_init = []
        w.instantiate = Mock()
        w.register_with_event_loop(parent, Mock(name='loop'))
        assert parent.consumer.on_task_message

    def test_info_without_event_loop(self):
        parent = Mock(name='parent')
        parent.autoscale = True
        parent.max_concurrency = '10'
        parent.min_concurrency = '2'
        parent.use_eventloop = False
        w = autoscale.WorkerComponent(parent)
        w.create(parent)
        info = w.info(parent)
        assert 'autoscaler' in info
        assert parent.autoscaler_cls().info.called
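
# The tests below build an Autoscaler bounded by max_concurrency=10 and
# min_concurrency=3 around a MockPool that starts with 3 processes.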
class test_Autoscaler:

    def setup(self):
        self.pool = MockPool(3)

    def test_stop(self):

        class Scaler(autoscale.Autoscaler):
            alive = True
            joined = False

            def is_alive(self):
                return self.alive

            def join(self, timeout=None):
                self.joined = True

        worker = Mock(name='worker')
        x = Scaler(self.pool, 10, 3, worker=worker)
        x._is_stopped.set()
        x.stop()
        assert x.joined
        x.joined = False
        x.alive = False
        x.stop()
        assert not x.joined
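
    # test_body() makes one scaling decision per call: with more reserved
    # tasks than processes the pool is grown toward max_concurrency, while
    # shrinking back down is expected to wait until the keepalive interval
    # since the last scale-up has elapsed, which is why _last_scale_up is
    # rewound into the past below.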
    @mock.sleepdeprived(module=autoscale)
    def test_body(self):
        worker = Mock(name='worker')
        x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
        x.body()
        assert x.pool.num_processes == 3
        _keep = [Mock(name='req{0}'.format(i)) for i in range(20)]
        [state.task_reserved(m) for m in _keep]
        x.body()
        x.body()
        assert x.pool.num_processes == 10
        worker.consumer._update_prefetch_count.assert_called()
        state.reserved_requests.clear()
        x.body()
        assert x.pool.num_processes == 10
        x._last_scale_up = monotonic() - 10000
        x.body()
        assert x.pool.num_processes == 3
        worker.consumer._update_prefetch_count.assert_called()

    def test_run(self):

        class Scaler(autoscale.Autoscaler):
            scale_called = False

            def body(self):
                self.scale_called = True
                self._is_shutdown.set()

        worker = Mock(name='worker')
        x = Scaler(self.pool, 10, 3, worker=worker)
        x.run()
        assert x._is_shutdown.isSet()
        assert x._is_stopped.isSet()
        assert x.scale_called

    def test_shrink_raises_exception(self):
        worker = Mock(name='worker')
        x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
        x.scale_up(3)
        x.pool.shrink_raises_exception = True
        x._shrink(1)

    @patch('celery.worker.autoscale.debug')
    def test_shrink_raises_ValueError(self, debug):
        worker = Mock(name='worker')
        x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
        x.scale_up(3)
        x._last_scale_up = monotonic() - 10000
        x.pool.shrink_raises_ValueError = True
        x.scale_down(1)
        assert debug.call_count

    def test_update_and_force(self):
        worker = Mock(name='worker')
        x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
        assert x.processes == 3
        x.force_scale_up(5)
        assert x.processes == 8
        x.update(5, None)
        assert x.processes == 5
        x.force_scale_down(3)
        assert x.processes == 2
        x.update(None, 3)
        assert x.processes == 3
        x.force_scale_down(1000)
        assert x.min_concurrency == 0
        assert x.processes == 0
        x.force_scale_up(1000)
        x.min_concurrency = 1
        x.force_scale_down(1)

        x.update(max=300, min=10)
        x.update(max=300, min=2)
        x.update(max=None, min=None)

    def test_info(self):
        worker = Mock(name='worker')
        x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
        info = x.info()
        assert info['max'] == 10
        assert info['min'] == 3
        assert info['current'] == 3

    @patch('os._exit')
    def test_thread_crash(self, _exit):

        class _Autoscaler(autoscale.Autoscaler):

            def body(self):
                self._is_shutdown.set()
                raise OSError('foo')

        worker = Mock(name='worker')
        x = _Autoscaler(self.pool, 10, 3, worker=worker)

        stderr = Mock()
        p, sys.stderr = sys.stderr, stderr
        try:
            x.run()
        finally:
            sys.stderr = p
        _exit.assert_called_with(1)
        stderr.write.assert_called()

    @mock.sleepdeprived(module=autoscale)
    def test_no_negative_scale(self):
        total_num_processes = []
        worker = Mock(name='worker')
        x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
        x.body()  # the body func scales up or down

        _keep = [Mock(name='req{0}'.format(i)) for i in range(35)]
        for req in _keep:
            state.task_reserved(req)
        x.body()
        total_num_processes.append(self.pool.num_processes)

        for req in _keep:
            state.task_ready(req)
        x.body()
        total_num_processes.append(self.pool.num_processes)

        assert all(x.min_concurrency <= i <= x.max_concurrency
                   for i in total_num_processes)