| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 | from __future__ import absolute_import, unicode_literalsimport pytestfrom collections import defaultdictfrom contextlib import contextmanagerfrom case import Mock, patchfrom kombu.utils.limits import TokenBucketfrom celery.exceptions import InvalidTaskErrorfrom celery.worker import statefrom celery.worker.strategy import proto1_to_proto2from celery.utils.time import rateclass test_proto1_to_proto2:    def setup(self):        self.message = Mock(name='message')        self.body = {            'args': (1,),            'kwargs': {'foo': 'baz'},            'utc': False,            'taskset': '123',        }    def test_message_without_args(self):        self.body.pop('args')        with pytest.raises(InvalidTaskError):            proto1_to_proto2(self.message, self.body)    def test_message_without_kwargs(self):        self.body.pop('kwargs')        with pytest.raises(InvalidTaskError):            proto1_to_proto2(self.message, self.body)    def test_message_kwargs_not_mapping(self):        self.body['kwargs'] = (2,)        with pytest.raises(InvalidTaskError):            proto1_to_proto2(self.message, self.body)    def test_message_no_taskset_id(self):        self.body.pop('taskset')        assert proto1_to_proto2(self.message, self.body)    def test_message(self):        body, headers, decoded, utc = proto1_to_proto2(self.message, self.body)        assert body == ((1,), {'foo': 'baz'}, {            'callbacks': None, 'errbacks': None, 'chord': None, 'chain': None,        })        assert headers == dict(self.body, group='123')        assert decoded        assert not utcclass test_default_strategy_proto2:    def setup(self):        @self.app.task(shared=False)        def add(x, y):            return x + y        self.add = add    def get_message_class(self):        return self.TaskMessage    def prepare_message(self, message):        return message    class Context(object):        def __init__(self, sig, s, reserved, consumer, message):            self.sig = sig            self.s = s            self.reserved = reserved            self.consumer = consumer            self.message = message        def __call__(self, callbacks=[], **kwargs):            return self.s(                self.message,                (self.message.payload                    if not self.message.headers.get('id') else None),                self.message.ack, self.message.reject, callbacks, **kwargs            )        def was_reserved(self):            return self.reserved.called        def was_rate_limited(self):            assert not self.was_reserved()            return self.consumer._limit_task.called        def was_scheduled(self):            assert not self.was_reserved()            assert not self.was_rate_limited()            return self.consumer.timer.call_at.called        def event_sent(self):            return self.consumer.event_dispatcher.send.call_args        def get_request(self):            if self.was_reserved():                return self.reserved.call_args[0][0]            if self.was_rate_limited():                return self.consumer._limit_task.call_args[0][0]            if self.was_scheduled():                return self.consumer.timer.call_at.call_args[0][0]            raise ValueError('request not handled')    @contextmanager    def _context(self, sig,                 rate_limits=True, events=True, utc=True, limit=None):        assert sig.type.Strategy        reserved = Mock()        consumer = Mock()        consumer.task_buckets = defaultdict(lambda: None)        if limit:            bucket = TokenBucket(rate(limit), capacity=1)            consumer.task_buckets[sig.task] = bucket        consumer.controller.state.revoked = set()        consumer.disable_rate_limits = not rate_limits        consumer.event_dispatcher.enabled = events        s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)        assert s        message = self.task_message_from_sig(            self.app, sig, utc=utc, TaskMessage=self.get_message_class(),        )        message = self.prepare_message(message)        yield self.Context(sig, s, reserved, consumer, message)    def test_when_logging_disabled(self):        with patch('celery.worker.strategy.logger') as logger:            logger.isEnabledFor.return_value = False            with self._context(self.add.s(2, 2)) as C:                C()                logger.info.assert_not_called()    def test_task_strategy(self):        with self._context(self.add.s(2, 2)) as C:            C()            assert C.was_reserved()            req = C.get_request()            C.consumer.on_task_request.assert_called_with(req)            assert C.event_sent()    def test_callbacks(self):        with self._context(self.add.s(2, 2)) as C:            callbacks = [Mock(name='cb1'), Mock(name='cb2')]            C(callbacks=callbacks)            req = C.get_request()            for callback in callbacks:                callback.assert_called_with(req)    def test_when_events_disabled(self):        with self._context(self.add.s(2, 2), events=False) as C:            C()            assert C.was_reserved()            assert not C.event_sent()    def test_eta_task(self):        with self._context(self.add.s(2, 2).set(countdown=10)) as C:            C()            assert C.was_scheduled()            C.consumer.qos.increment_eventually.assert_called_with()    def test_eta_task_utc_disabled(self):        with self._context(self.add.s(2, 2).set(countdown=10), utc=False) as C:            C()            assert C.was_scheduled()            C.consumer.qos.increment_eventually.assert_called_with()    def test_when_rate_limited(self):        task = self.add.s(2, 2)        with self._context(task, rate_limits=True, limit='1/m') as C:            C()            assert C.was_rate_limited()    def test_when_rate_limited__limits_disabled(self):        task = self.add.s(2, 2)        with self._context(task, rate_limits=False, limit='1/m') as C:            C()            assert C.was_reserved()    def test_when_revoked(self):        task = self.add.s(2, 2)        task.freeze()        try:            with self._context(task) as C:                C.consumer.controller.state.revoked.add(task.id)                state.revoked.add(task.id)                C()                with pytest.raises(ValueError):                    C.get_request()        finally:            state.revoked.discard(task.id)class test_default_strategy_proto1(test_default_strategy_proto2):    def get_message_class(self):        return self.TaskMessage1class test_default_strategy_proto1__no_utc(test_default_strategy_proto2):    def get_message_class(self):        return self.TaskMessage1    def prepare_message(self, message):        message.payload['utc'] = False        return message
 |