test_strategy.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. from __future__ import absolute_import
  2. from collections import defaultdict
  3. from contextlib import contextmanager
  4. from kombu.utils.limits import TokenBucket
  5. from celery.worker import state
  6. from celery.utils.timeutils import rate
  7. from celery.tests.case import AppCase, Mock, patch, body_from_sig
  8. class test_default_strategy(AppCase):
  9. def setup(self):
  10. @self.app.task(shared=False)
  11. def add(x, y):
  12. return x + y
  13. self.add = add
  14. class Context(object):
  15. def __init__(self, sig, s, reserved, consumer, message, body):
  16. self.sig = sig
  17. self.s = s
  18. self.reserved = reserved
  19. self.consumer = consumer
  20. self.message = message
  21. self.body = body
  22. def __call__(self, **kwargs):
  23. return self.s(
  24. self.message, self.body,
  25. self.message.ack, self.message.reject, [], **kwargs
  26. )
  27. def was_reserved(self):
  28. return self.reserved.called
  29. def was_rate_limited(self):
  30. assert not self.was_reserved()
  31. return self.consumer._limit_task.called
  32. def was_scheduled(self):
  33. assert not self.was_reserved()
  34. assert not self.was_rate_limited()
  35. return self.consumer.timer.call_at.called
  36. def event_sent(self):
  37. return self.consumer.event_dispatcher.send.call_args
  38. def get_request(self):
  39. if self.was_reserved():
  40. return self.reserved.call_args[0][0]
  41. if self.was_rate_limited():
  42. return self.consumer._limit_task.call_args[0][0]
  43. if self.was_scheduled():
  44. return self.consumer.timer.call_at.call_args[0][0]
  45. raise ValueError('request not handled')
  46. @contextmanager
  47. def _context(self, sig,
  48. rate_limits=True, events=True, utc=True, limit=None):
  49. self.assertTrue(sig.type.Strategy)
  50. reserved = Mock()
  51. consumer = Mock()
  52. consumer.task_buckets = defaultdict(lambda: None)
  53. if limit:
  54. bucket = TokenBucket(rate(limit), capacity=1)
  55. consumer.task_buckets[sig.task] = bucket
  56. consumer.disable_rate_limits = not rate_limits
  57. consumer.event_dispatcher.enabled = events
  58. s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)
  59. self.assertTrue(s)
  60. message = Mock()
  61. body = body_from_sig(self.app, sig, utc=utc)
  62. yield self.Context(sig, s, reserved, consumer, message, body)
  63. def test_when_logging_disabled(self):
  64. with patch('celery.worker.strategy.logger') as logger:
  65. logger.isEnabledFor.return_value = False
  66. with self._context(self.add.s(2, 2)) as C:
  67. C()
  68. self.assertFalse(logger.info.called)
  69. def test_task_strategy(self):
  70. with self._context(self.add.s(2, 2)) as C:
  71. C()
  72. self.assertTrue(C.was_reserved())
  73. req = C.get_request()
  74. C.consumer.on_task_request.assert_called_with(req)
  75. self.assertTrue(C.event_sent())
  76. def test_when_events_disabled(self):
  77. with self._context(self.add.s(2, 2), events=False) as C:
  78. C()
  79. self.assertTrue(C.was_reserved())
  80. self.assertFalse(C.event_sent())
  81. def test_eta_task(self):
  82. with self._context(self.add.s(2, 2).set(countdown=10)) as C:
  83. C()
  84. self.assertTrue(C.was_scheduled())
  85. C.consumer.qos.increment_eventually.assert_called_with()
  86. def test_eta_task_utc_disabled(self):
  87. with self._context(self.add.s(2, 2).set(countdown=10), utc=False) as C:
  88. C()
  89. self.assertTrue(C.was_scheduled())
  90. C.consumer.qos.increment_eventually.assert_called_with()
  91. def test_when_rate_limited(self):
  92. task = self.add.s(2, 2)
  93. with self._context(task, rate_limits=True, limit='1/m') as C:
  94. C()
  95. self.assertTrue(C.was_rate_limited())
  96. def test_when_rate_limited__limits_disabled(self):
  97. task = self.add.s(2, 2)
  98. with self._context(task, rate_limits=False, limit='1/m') as C:
  99. C()
  100. self.assertTrue(C.was_reserved())
  101. def test_when_revoked(self):
  102. task = self.add.s(2, 2)
  103. task.freeze()
  104. state.revoked.add(task.id)
  105. try:
  106. with self._context(task) as C:
  107. C()
  108. with self.assertRaises(ValueError):
  109. C.get_request()
  110. finally:
  111. state.revoked.discard(task.id)