test_autoscale.py 6.4 KB


  1. from __future__ import absolute_import, unicode_literals
  2. import sys
  3. from celery.concurrency.base import BasePool
  4. from celery.five import monotonic
  5. from celery.worker import state
  6. from celery.worker import autoscale
  7. from celery.tests.case import AppCase, Mock, patch, sleepdeprived
  8. from celery.utils.objects import Bunch
  class MockPool(BasePool):
      """Minimal in-memory pool used as a stand-in by the autoscaler tests.

      The current process count is kept on ``self._pool._processes`` and is
      seeded from ``self.limit``.  Two class-level switches let a test make
      :meth:`shrink` fail on demand.
      """

      # When true, ``shrink`` raises ``KeyError('foo')``.
      shrink_raises_exception = False
      # When true (and the switch above is false), ``shrink`` raises
      # ``ValueError('foo')``.
      shrink_raises_ValueError = False

      def __init__(self, *args, **kwargs):
          super(MockPool, self).__init__(*args, **kwargs)
          # ``Bunch`` just gives us an object with a ``_processes`` attribute.
          self._pool = Bunch(_processes=self.limit)

      def grow(self, n=1):
          """Increase the recorded process count by ``n``."""
          backing = self._pool
          backing._processes += n

      def shrink(self, n=1):
          """Decrease the recorded process count by ``n``.

          Raises ``KeyError`` or ``ValueError`` instead when the matching
          class switch is enabled; the ``KeyError`` switch is checked first.
          """
          if self.shrink_raises_exception:
              raise KeyError('foo')
          elif self.shrink_raises_ValueError:
              raise ValueError('foo')
          backing = self._pool
          backing._processes -= n

      @property
      def num_processes(self):
          """The process count currently recorded on the backing object."""
          return self._pool._processes
  class test_WorkerComponent(AppCase):
      """Tests for ``autoscale.WorkerComponent``."""

      def test_register_with_event_loop(self):
          # A mocked worker with autoscaling switched on and an empty set of
          # task-message callbacks.
          parent = Mock(name='parent')
          parent.autoscale = True
          parent.consumer.on_task_message = set()

          component = autoscale.WorkerComponent(parent)
          self.assertIsNone(parent.autoscaler)
          self.assertTrue(component.enabled)

          event_hub = Mock(name='hub')
          component.create(parent)
          component.register_with_event_loop(parent, event_hub)

          # Registration must hook ``maybe_scale`` into the consumer callbacks
          # and schedule it on the hub at the keepalive interval.
          scaler = parent.autoscaler
          self.assertIn(scaler.maybe_scale, parent.consumer.on_task_message)
          event_hub.call_repeatedly.assert_called_with(
              scaler.keepalive, scaler.maybe_scale,
          )

          # Register a second time against a different loop object, with the
          # hub attached to the parent and ``instantiate`` stubbed out.
          event_hub.on_init = []
          parent.hub = event_hub
          component.instantiate = Mock()
          component.register_with_event_loop(parent, Mock(name='loop'))
          self.assertTrue(parent.consumer.on_task_message)
  class test_Autoscaler(AppCase):
      """Tests for ``autoscale.Autoscaler`` driven through a ``MockPool``.

      Test methods carry ``#`` comments rather than docstrings so that test
      runners keep reporting them by name.
      """

      def setup(self):
          # Every test starts from a fresh pool created with a limit of 3.
          self.pool = MockPool(3)

      def test_stop(self):
          # ``stop()`` is expected to join the thread only while it reports
          # itself as alive.
          class Scaler(autoscale.Autoscaler):
              alive = True
              joined = False

              def is_alive(self):
                  return self.alive

              def join(self, timeout=None):
                  self.joined = True

          worker = Mock(name='worker')
          x = Scaler(self.pool, 10, 3, worker=worker)
          x._is_stopped.set()
          x.stop()
          self.assertTrue(x.joined)

          x.joined = False
          x.alive = False
          x.stop()
          self.assertFalse(x.joined)

      @sleepdeprived(autoscale)
      def test_body(self):
          # ``state.reserved_requests`` is module-level state shared with the
          # other tests: guarantee it is emptied even if an assertion fails.
          self.addCleanup(state.reserved_requests.clear)

          worker = Mock(name='worker')
          x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)

          # Nothing reserved: the pool stays at its initial size.
          x.body()
          self.assertEqual(x.pool.num_processes, 3)

          # Twenty reserved requests: the pool grows up to the max of 10.
          for i in range(20):
              state.reserved_requests.add(i)
          x.body()
          x.body()
          self.assertEqual(x.pool.num_processes, 10)
          self.assertTrue(worker.consumer._update_prefetch_count.called)

          # Load is gone, but the last scale-up was recent: no shrink yet.
          state.reserved_requests.clear()
          x.body()
          self.assertEqual(x.pool.num_processes, 10)

          # Pretend the last scale-up happened long ago: now it shrinks back.
          x._last_scale_up = monotonic() - 10000
          x.body()
          self.assertEqual(x.pool.num_processes, 3)
          self.assertTrue(worker.consumer._update_prefetch_count.called)

      def test_run(self):
          # ``run()`` must call ``body()`` and leave both events set once the
          # body requests shutdown.
          class Scaler(autoscale.Autoscaler):
              scale_called = False

              def body(self):
                  self.scale_called = True
                  self._is_shutdown.set()

          worker = Mock(name='worker')
          x = Scaler(self.pool, 10, 3, worker=worker)
          x.run()
          self.assertTrue(x._is_shutdown.isSet())
          self.assertTrue(x._is_stopped.isSet())
          self.assertTrue(x.scale_called)

      def test_shrink_raises_exception(self):
          # No assertion: the test passes as long as the ``KeyError`` raised
          # by the pool does not propagate out of ``_shrink``.
          worker = Mock(name='worker')
          x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
          x.scale_up(3)
          # Same attribute the other tests in this class age; ``_shrink`` is
          # called directly below, so this only keeps the setup consistent.
          x._last_scale_up = monotonic() - 10000
          x.pool.shrink_raises_exception = True
          x._shrink(1)

      @patch('celery.worker.autoscale.debug')
      def test_shrink_raises_ValueError(self, debug):
          # A ``ValueError`` from the pool during scale-down must be reported
          # through the module's ``debug`` logger rather than raised.
          worker = Mock(name='worker')
          x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
          x.scale_up(3)
          x._last_scale_up = monotonic() - 10000
          x.pool.shrink_raises_ValueError = True
          x.scale_down(1)
          self.assertTrue(debug.call_count)

      def test_update_and_force(self):
          worker = Mock(name='worker')
          x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
          self.assertEqual(x.processes, 3)

          x.force_scale_up(5)
          self.assertEqual(x.processes, 8)
          x.update(5, None)
          self.assertEqual(x.processes, 5)

          x.force_scale_down(3)
          self.assertEqual(x.processes, 2)
          x.update(None, 3)
          self.assertEqual(x.processes, 3)

          # Forcing down by more than is available bottoms out at zero.
          x.force_scale_down(1000)
          self.assertEqual(x.min_concurrency, 0)
          self.assertEqual(x.processes, 0)

          # The remaining calls carry no assertions; they only need to
          # complete without raising.
          x.force_scale_up(1000)
          x.min_concurrency = 1
          x.force_scale_down(1)
          x.update(max=300, min=10)
          x.update(max=300, min=2)
          x.update(max=None, min=None)

      def test_info(self):
          worker = Mock(name='worker')
          x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
          info = x.info()
          self.assertEqual(info['max'], 10)
          self.assertEqual(info['min'], 3)
          self.assertEqual(info['current'], 3)

      @patch('os._exit')
      def test_thread_crash(self, _exit):
          # An exception escaping ``body()`` should be written to stderr and
          # end with ``os._exit(1)`` (patched here so the test run survives).
          class _Autoscaler(autoscale.Autoscaler):

              def body(self):
                  self._is_shutdown.set()
                  raise OSError('foo')

          worker = Mock(name='worker')
          x = _Autoscaler(self.pool, 10, 3, worker=worker)

          stderr = Mock()
          p, sys.stderr = sys.stderr, stderr
          try:
              x.run()
          finally:
              # Always restore the real stream, even if ``run()`` blows up.
              sys.stderr = p
          _exit.assert_called_with(1)
          self.assertTrue(stderr.write.call_count)

      # ``body()`` is invoked 71 times below; patch out sleeping in the
      # autoscale module exactly as ``test_body`` does so the test cannot
      # stall on real sleeps.
      @sleepdeprived(autoscale)
      def test_no_negative_scale(self):
          # Do not leak reserved requests into other tests if we fail midway.
          self.addCleanup(state.reserved_requests.clear)

          total_num_processes = []
          worker = Mock(name='worker')
          x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
          x.body()  # the body func scales up or down

          # Ramp the load up one request at a time, sampling the pool size.
          for i in range(35):
              state.reserved_requests.add(i)
              x.body()
              total_num_processes.append(self.pool.num_processes)

          # ...and back down again.
          for i in range(35):
              state.reserved_requests.remove(i)
              x.body()
              total_num_processes.append(self.pool.num_processes)

          # The pool size must never leave the [min, max] concurrency window.
          self.assertTrue(
              all(x.min_concurrency <= i <= x.max_concurrency
                  for i in total_num_processes)
          )