123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- from __future__ import absolute_import
- from collections import defaultdict
- from contextlib import contextmanager
- from mock import Mock, patch
- from kombu.utils.limits import TokenBucket
- from celery.worker import state
- from celery.utils.timeutils import rate
- from celery.tests.case import AppCase, body_from_sig
class test_default_strategy(AppCase):
    """Exercise the default task strategy against a mocked consumer.

    Each test builds a signature for the ``add`` task and starts its
    strategy through ``_context``. It then checks where the strategy
    routed the message: to the ``task_reserved`` callback, to the
    consumer's ``_limit_task``, or to the consumer's ``timer.call_at``.
    """

    def setup(self):
        # Register a throwaway, non-shared task on this test's app.
        @self.app.task(shared=False)
        def add(x, y):
            return x + y

        self.add = add

    class Context(object):
        """Bundle of the strategy, its mocks and inputs for one test."""

        def __init__(self, sig, s, reserved, consumer, message, body):
            self.sig, self.s, self.reserved = sig, s, reserved
            self.consumer, self.message, self.body = consumer, message, body

        def __call__(self, **kwargs):
            # Feed the prepared message/body to the strategy; the message
            # mock's ``ack`` attribute is passed as the third argument.
            msg = self.message
            return self.s(msg, self.body, msg.ack, **kwargs)

        def was_reserved(self):
            """Return whether the ``task_reserved`` mock was called."""
            return self.reserved.called

        def was_rate_limited(self):
            """Return whether ``consumer._limit_task`` was called.

            Also asserts that the request was not reserved.
            """
            assert not self.was_reserved()
            limiter = self.consumer._limit_task
            return limiter.called

        def was_scheduled(self):
            """Return whether ``consumer.timer.call_at`` was called.

            Also asserts that the request was neither reserved nor
            rate limited.
            """
            assert not (self.was_reserved() or self.was_rate_limited())
            return self.consumer.timer.call_at.called

        def event_sent(self):
            """Return ``call_args`` of ``event_dispatcher.send`` (falsy if unsent)."""
            dispatcher = self.consumer.event_dispatcher
            return dispatcher.send.call_args

        def get_request(self):
            """Return the request object the strategy handed off.

            The reserve callback, the rate limiter and the timer are
            checked in that order. The first positional argument of the
            first one that was called is returned.

            :raises ValueError: if none of them was called.
            """
            # Sinks are looked up lazily so each is only touched after the
            # preceding checks have passed, exactly as an if-chain would.
            routes = (
                (self.was_reserved, lambda: self.reserved),
                (self.was_rate_limited, lambda: self.consumer._limit_task),
                (self.was_scheduled, lambda: self.consumer.timer.call_at),
            )
            for was_routed, sink in routes:
                if was_routed():
                    return sink().call_args[0][0]
            raise ValueError('request not handled')

    @contextmanager
    def _context(self, sig,
                 rate_limits=True, events=True, utc=True, limit=None):
        """Start ``sig``'s strategy on a mocked consumer; yield a Context.

        :param sig: signature whose task type provides the strategy.
        :param rate_limits: ``consumer.disable_rate_limits`` is set to its
            negation.
        :param events: value for ``consumer.event_dispatcher.enabled``.
        :param utc: forwarded to ``body_from_sig``.
        :param limit: optional rate string (e.g. ``'1/m'``). When truthy,
            a ``TokenBucket`` with capacity 1 is installed for
            ``sig.task``.
        """
        self.assertTrue(sig.type.Strategy)

        task_reserved = Mock()
        consumer = Mock()
        # Tasks without an explicit bucket map to None instead of raising
        # KeyError.
        consumer.task_buckets = defaultdict(lambda: None)
        if limit:
            consumer.task_buckets[sig.task] = TokenBucket(
                rate(limit), capacity=1)
        consumer.disable_rate_limits = not rate_limits
        consumer.event_dispatcher.enabled = events

        strategy = sig.type.start_strategy(
            self.app, consumer, task_reserved=task_reserved)
        self.assertTrue(strategy)

        yield self.Context(
            sig, strategy, task_reserved, consumer,
            Mock(), body_from_sig(self.app, sig, utc=utc))

    def test_when_logging_disabled(self):
        with patch('celery.worker.strategy.logger') as logger:
            # Configured before ``_context`` starts the strategy, so any
            # log-level check made at start-up already sees it.
            logger.isEnabledFor.return_value = False
            with self._context(self.add.s(2, 2)) as ctx:
                ctx()
                self.assertFalse(logger.info.called)

    def test_task_strategy(self):
        with self._context(self.add.s(2, 2)) as ctx:
            ctx()
            self.assertTrue(ctx.was_reserved())
            request = ctx.get_request()
            ctx.consumer.on_task_request.assert_called_with(request)
            self.assertTrue(ctx.event_sent())

    def test_when_events_disabled(self):
        with self._context(self.add.s(2, 2), events=False) as ctx:
            ctx()
            self.assertTrue(ctx.was_reserved())
            self.assertFalse(ctx.event_sent())

    def _check_eta_scheduled(self, **context_options):
        """Run a countdown task and check it was sent to the timer."""
        sig = self.add.s(2, 2).set(countdown=10)
        with self._context(sig, **context_options) as ctx:
            ctx()
            self.assertTrue(ctx.was_scheduled())
            ctx.consumer.qos.increment_eventually.assert_called_with()

    def test_eta_task(self):
        self._check_eta_scheduled()

    def test_eta_task_utc_disabled(self):
        self._check_eta_scheduled(utc=False)

    def test_when_rate_limited(self):
        sig = self.add.s(2, 2)
        with self._context(sig, rate_limits=True, limit='1/m') as ctx:
            ctx()
            self.assertTrue(ctx.was_rate_limited())

    def test_when_rate_limited__limits_disabled(self):
        sig = self.add.s(2, 2)
        with self._context(sig, rate_limits=False, limit='1/m') as ctx:
            ctx()
            self.assertTrue(ctx.was_reserved())

    def test_when_revoked(self):
        sig = self.add.s(2, 2)
        sig.freeze()  # freeze before reading sig.id below
        state.revoked.add(sig.id)
        try:
            with self._context(sig) as ctx:
                ctx()
                # Nothing should have reserved, limited or scheduled it.
                with self.assertRaises(ValueError):
                    ctx.get_request()
        finally:
            state.revoked.discard(sig.id)
|