|
@@ -0,0 +1,162 @@
|
|
|
+from __future__ import absolute_import
|
|
|
+
|
|
|
+from collections import defaultdict
|
|
|
+from contextlib import contextmanager
|
|
|
+from datetime import timedelta
|
|
|
+from mock import Mock, patch
|
|
|
+
|
|
|
+from kombu.utils.limits import TokenBucket
|
|
|
+
|
|
|
+from celery import Celery
|
|
|
+from celery.worker import state
|
|
|
+from celery.utils.timeutils import rate
|
|
|
+
|
|
|
+from celery.tests.utils import AppCase
|
|
|
+
|
|
|
+
|
|
|
+def body_from_sig(app, sig, utc=True):
|
|
|
+ sig._freeze()
|
|
|
+ callbacks = sig.options.pop('link', None)
|
|
|
+ errbacks = sig.options.pop('link_error', None)
|
|
|
+ countdown = sig.options.pop('countdown', None)
|
|
|
+ if countdown:
|
|
|
+ sig.options['eta'] = app.now() + timedelta(seconds=countdown)
|
|
|
+ eta = sig.options.pop('eta', None)
|
|
|
+ eta = eta.isoformat() if eta else None
|
|
|
+ return {
|
|
|
+ 'task': sig.task,
|
|
|
+ 'id': sig.id,
|
|
|
+ 'args': sig.args,
|
|
|
+ 'kwargs': sig.kwargs,
|
|
|
+ 'callbacks': [dict(s) for s in callbacks] if callbacks else None,
|
|
|
+ 'errbacks': [dict(s) for s in errbacks] if errbacks else None,
|
|
|
+ 'eta': eta,
|
|
|
+ 'utc': utc,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class test_default_strategy(AppCase):
|
|
|
+
|
|
|
+ class Context(object):
|
|
|
+
|
|
|
+ def __init__(self, sig, s, reserved, consumer, message, body):
|
|
|
+ self.sig = sig
|
|
|
+ self.s = s
|
|
|
+ self.reserved = reserved
|
|
|
+ self.consumer = consumer
|
|
|
+ self.message = message
|
|
|
+ self.body = body
|
|
|
+
|
|
|
+ def __call__(self, **kwargs):
|
|
|
+ return self.s(self.message, self.body, self.message.ack, **kwargs)
|
|
|
+
|
|
|
+ def was_reserved(self):
|
|
|
+ return self.reserved.called
|
|
|
+
|
|
|
+ def was_rate_limited(self):
|
|
|
+ assert not self.was_reserved()
|
|
|
+ return self.consumer._limit_task.called
|
|
|
+
|
|
|
+ def was_scheduled(self):
|
|
|
+ assert not self.was_reserved()
|
|
|
+ assert not self.was_rate_limited()
|
|
|
+ return self.consumer.timer.apply_at.called
|
|
|
+
|
|
|
+ def event_sent(self):
|
|
|
+ return self.consumer.event_dispatcher.send.call_args
|
|
|
+
|
|
|
+ def get_request(self):
|
|
|
+ if self.was_reserved():
|
|
|
+ return self.reserved.call_args[0][0]
|
|
|
+ if self.was_rate_limited():
|
|
|
+ return self.consumer._limit_task.call_args[0][0]
|
|
|
+ if self.was_scheduled():
|
|
|
+ return self.consumer.timer.apply_at.call_args[0][0]
|
|
|
+ raise ValueError('request not handled')
|
|
|
+
|
|
|
+ def setup(self):
|
|
|
+ self.c = Celery(set_as_current=False)
|
|
|
+
|
|
|
+ @self.c.task()
|
|
|
+ def add(x, y):
|
|
|
+ return x + y
|
|
|
+
|
|
|
+ self.add = add
|
|
|
+
|
|
|
+ @contextmanager
|
|
|
+ def _context(self, sig,
|
|
|
+ rate_limits=True, events=True, utc=True, limit=None):
|
|
|
+ self.assertTrue(sig.type.Strategy)
|
|
|
+
|
|
|
+ reserved = Mock()
|
|
|
+ consumer = Mock()
|
|
|
+ consumer.task_buckets = defaultdict(lambda: None)
|
|
|
+ if limit:
|
|
|
+ bucket = TokenBucket(rate(limit), capacity=1)
|
|
|
+ consumer.task_buckets[sig.task] = bucket
|
|
|
+ consumer.disable_rate_limits = not rate_limits
|
|
|
+ consumer.event_dispatcher.enabled = events
|
|
|
+ s = sig.type.start_strategy(self.c, consumer, task_reserved=reserved)
|
|
|
+ self.assertTrue(s)
|
|
|
+
|
|
|
+ message = Mock()
|
|
|
+ body = body_from_sig(self.c, sig, utc=utc)
|
|
|
+
|
|
|
+ yield self.Context(sig, s, reserved, consumer, message, body)
|
|
|
+
|
|
|
+ def test_when_logging_disabled(self):
|
|
|
+ with patch('celery.worker.strategy.logger') as logger:
|
|
|
+ logger.isEnabledFor.return_value = False
|
|
|
+ with self._context(self.add.s(2, 2)) as C:
|
|
|
+ C()
|
|
|
+ self.assertFalse(logger.info.called)
|
|
|
+
|
|
|
+ def test_task_strategy(self):
|
|
|
+ with self._context(self.add.s(2, 2)) as C:
|
|
|
+ C()
|
|
|
+ self.assertTrue(C.was_reserved())
|
|
|
+ req = C.get_request()
|
|
|
+ C.consumer.handle_task.assert_called_with(req)
|
|
|
+ self.assertTrue(C.event_sent())
|
|
|
+
|
|
|
+ def test_when_events_disabled(self):
|
|
|
+ with self._context(self.add.s(2, 2), events=False) as C:
|
|
|
+ C()
|
|
|
+ self.assertTrue(C.was_reserved())
|
|
|
+ self.assertFalse(C.event_sent())
|
|
|
+
|
|
|
+ def test_eta_task(self):
|
|
|
+ with self._context(self.add.s(2, 2).set(countdown=10)) as C:
|
|
|
+ C()
|
|
|
+ self.assertTrue(C.was_scheduled())
|
|
|
+ C.consumer.qos.increment_eventually.assert_called_with()
|
|
|
+
|
|
|
+ def test_eta_task_utc_disabled(self):
|
|
|
+ with self._context(self.add.s(2, 2).set(countdown=10), utc=False) as C:
|
|
|
+ C()
|
|
|
+ self.assertTrue(C.was_scheduled())
|
|
|
+ C.consumer.qos.increment_eventually.assert_called_with()
|
|
|
+
|
|
|
+ def test_when_rate_limited(self):
|
|
|
+ task = self.add.s(2, 2)
|
|
|
+ with self._context(task, rate_limits=True, limit='1/m') as C:
|
|
|
+ C()
|
|
|
+ self.assertTrue(C.was_rate_limited())
|
|
|
+
|
|
|
+ def test_when_rate_limited__limits_disabled(self):
|
|
|
+ task = self.add.s(2, 2)
|
|
|
+ with self._context(task, rate_limits=False, limit='1/m') as C:
|
|
|
+ C()
|
|
|
+ self.assertTrue(C.was_reserved())
|
|
|
+
|
|
|
+ def test_when_revoked(self):
|
|
|
+ task = self.add.s(2, 2)
|
|
|
+ task._freeze()
|
|
|
+ state.revoked.add(task.id)
|
|
|
+ try:
|
|
|
+ with self._context(task) as C:
|
|
|
+ C()
|
|
|
+ with self.assertRaises(ValueError):
|
|
|
+ C.get_request()
|
|
|
+ finally:
|
|
|
+ state.revoked.discard(task.id)
|