test_autoscale.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import sys
  2. from time import monotonic
  3. from case import Mock, mock, patch
  4. from celery.concurrency.base import BasePool
  5. from celery.worker import state
  6. from celery.worker import autoscale
  7. from celery.utils.objects import Bunch
  8. class MockPool(BasePool):
  9. shrink_raises_exception = False
  10. shrink_raises_ValueError = False
  11. def __init__(self, *args, **kwargs):
  12. super(MockPool, self).__init__(*args, **kwargs)
  13. self._pool = Bunch(_processes=self.limit)
  14. def grow(self, n=1):
  15. self._pool._processes += n
  16. def shrink(self, n=1):
  17. if self.shrink_raises_exception:
  18. raise KeyError('foo')
  19. if self.shrink_raises_ValueError:
  20. raise ValueError('foo')
  21. self._pool._processes -= n
  22. @property
  23. def num_processes(self):
  24. return self._pool._processes
  25. class test_WorkerComponent:
  26. def test_register_with_event_loop(self):
  27. parent = Mock(name='parent')
  28. parent.autoscale = True
  29. parent.consumer.on_task_message = set()
  30. w = autoscale.WorkerComponent(parent)
  31. assert parent.autoscaler is None
  32. assert w.enabled
  33. hub = Mock(name='hub')
  34. w.create(parent)
  35. w.register_with_event_loop(parent, hub)
  36. assert (parent.autoscaler.maybe_scale in
  37. parent.consumer.on_task_message)
  38. hub.call_repeatedly.assert_called_with(
  39. parent.autoscaler.keepalive, parent.autoscaler.maybe_scale,
  40. )
  41. parent.hub = hub
  42. hub.on_init = []
  43. w.instantiate = Mock()
  44. w.register_with_event_loop(parent, Mock(name='loop'))
  45. assert parent.consumer.on_task_message
  46. class test_Autoscaler:
  47. def setup(self):
  48. self.pool = MockPool(3)
  49. def test_stop(self):
  50. class Scaler(autoscale.Autoscaler):
  51. alive = True
  52. joined = False
  53. def is_alive(self):
  54. return self.alive
  55. def join(self, timeout=None):
  56. self.joined = True
  57. worker = Mock(name='worker')
  58. x = Scaler(self.pool, 10, 3, worker=worker)
  59. x._is_stopped.set()
  60. x.stop()
  61. assert x.joined
  62. x.joined = False
  63. x.alive = False
  64. x.stop()
  65. assert not x.joined
  66. @mock.sleepdeprived(module=autoscale)
  67. def test_body(self):
  68. worker = Mock(name='worker')
  69. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  70. x.body()
  71. assert x.pool.num_processes == 3
  72. _keep = [Mock(name='req{0}'.format(i)) for i in range(20)]
  73. [state.task_reserved(m) for m in _keep]
  74. x.body()
  75. x.body()
  76. assert x.pool.num_processes == 10
  77. worker.consumer._update_prefetch_count.assert_called()
  78. state.reserved_requests.clear()
  79. x.body()
  80. assert x.pool.num_processes == 10
  81. x._last_scale_up = monotonic() - 10000
  82. x.body()
  83. assert x.pool.num_processes == 3
  84. worker.consumer._update_prefetch_count.assert_called()
  85. def test_run(self):
  86. class Scaler(autoscale.Autoscaler):
  87. scale_called = False
  88. def body(self):
  89. self.scale_called = True
  90. self._is_shutdown.set()
  91. worker = Mock(name='worker')
  92. x = Scaler(self.pool, 10, 3, worker=worker)
  93. x.run()
  94. assert x._is_shutdown.isSet()
  95. assert x._is_stopped.isSet()
  96. assert x.scale_called
  97. def test_shrink_raises_exception(self):
  98. worker = Mock(name='worker')
  99. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  100. x.scale_up(3)
  101. x.pool.shrink_raises_exception = True
  102. x._shrink(1)
  103. @patch('celery.worker.autoscale.debug')
  104. def test_shrink_raises_ValueError(self, debug):
  105. worker = Mock(name='worker')
  106. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  107. x.scale_up(3)
  108. x._last_scale_up = monotonic() - 10000
  109. x.pool.shrink_raises_ValueError = True
  110. x.scale_down(1)
  111. assert debug.call_count
  112. def test_update_and_force(self):
  113. worker = Mock(name='worker')
  114. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  115. assert x.processes == 3
  116. x.force_scale_up(5)
  117. assert x.processes == 8
  118. x.update(5, None)
  119. assert x.processes == 5
  120. x.force_scale_down(3)
  121. assert x.processes == 2
  122. x.update(None, 3)
  123. assert x.processes == 3
  124. x.force_scale_down(1000)
  125. assert x.min_concurrency == 0
  126. assert x.processes == 0
  127. x.force_scale_up(1000)
  128. x.min_concurrency = 1
  129. x.force_scale_down(1)
  130. x.update(max=300, min=10)
  131. x.update(max=300, min=2)
  132. x.update(max=None, min=None)
  133. def test_info(self):
  134. worker = Mock(name='worker')
  135. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  136. info = x.info()
  137. assert info['max'] == 10
  138. assert info['min'] == 3
  139. assert info['current'] == 3
  140. @patch('os._exit')
  141. def test_thread_crash(self, _exit):
  142. class _Autoscaler(autoscale.Autoscaler):
  143. def body(self):
  144. self._is_shutdown.set()
  145. raise OSError('foo')
  146. worker = Mock(name='worker')
  147. x = _Autoscaler(self.pool, 10, 3, worker=worker)
  148. stderr = Mock()
  149. p, sys.stderr = sys.stderr, stderr
  150. try:
  151. x.run()
  152. finally:
  153. sys.stderr = p
  154. _exit.assert_called_with(1)
  155. stderr.write.assert_called()
  156. @mock.sleepdeprived(module=autoscale)
  157. def test_no_negative_scale(self):
  158. total_num_processes = []
  159. worker = Mock(name='worker')
  160. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  161. x.body() # the body func scales up or down
  162. _keep = [Mock(name='req{0}'.format(i)) for i in range(35)]
  163. for req in _keep:
  164. state.task_reserved(req)
  165. x.body()
  166. total_num_processes.append(self.pool.num_processes)
  167. for req in _keep:
  168. state.task_ready(req)
  169. x.body()
  170. total_num_processes.append(self.pool.num_processes)
  171. assert all(x.min_concurrency <= i <= x.max_concurrency
  172. for i in total_num_processes)