| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 | from __future__ import absolute_import, unicode_literalsfrom collections import defaultdictfrom contextlib import contextmanagerimport pytestfrom case import Mock, patchfrom kombu.utils.limits import TokenBucketfrom celery import Taskfrom celery.exceptions import InvalidTaskErrorfrom celery.utils.time import ratefrom celery.worker import statefrom celery.worker.request import Requestfrom celery.worker.strategy import default as default_strategyfrom celery.worker.strategy import proto1_to_proto2class test_proto1_to_proto2:    def setup(self):        self.message = Mock(name='message')        self.body = {            'args': (1,),            'kwargs': {'foo': 'baz'},            'utc': False,            'taskset': '123',        }    def test_message_without_args(self):        self.body.pop('args')        body, _, _, _ = proto1_to_proto2(self.message, self.body)        assert body[:2] == ((), {'foo': 'baz'})    def test_message_without_kwargs(self):        self.body.pop('kwargs')        body, _, _, _ = proto1_to_proto2(self.message, self.body)        assert body[:2] == ((1,), {})    def test_message_kwargs_not_mapping(self):        self.body['kwargs'] = (2,)        with pytest.raises(InvalidTaskError):            proto1_to_proto2(self.message, self.body)    def test_message_no_taskset_id(self):        self.body.pop('taskset')        assert proto1_to_proto2(self.message, self.body)    def test_message(self):        body, headers, decoded, utc = proto1_to_proto2(self.message, self.body)        assert body == ((1,), {'foo': 'baz'}, {            'callbacks': None, 'errbacks': None, 'chord': None, 'chain': None,        })        assert headers == dict(self.body, group='123')        assert decoded        assert not utcclass test_default_strategy_proto2:    def setup(self):        @self.app.task(shared=False)        def add(x, y):            return x + y        self.add = add    def get_message_class(self):        return self.TaskMessage    def prepare_message(self, message):        return message    class Context(object):        def __init__(self, sig, s, reserved, consumer, message):            self.sig = sig            self.s = s            self.reserved = reserved            self.consumer = consumer            self.message = message        def __call__(self, callbacks=[], **kwargs):            return self.s(                self.message,                (self.message.payload                    if not self.message.headers.get('id') else None),                self.message.ack, self.message.reject, callbacks, **kwargs            )        def was_reserved(self):            return self.reserved.called        def was_rate_limited(self):            assert not self.was_reserved()            return self.consumer._limit_task.called        def was_limited_with_eta(self):            assert not self.was_reserved()            called = self.consumer.timer.call_at.called            if called:                assert self.consumer.timer.call_at.call_args[0][1] == \                    self.consumer._limit_post_eta            return called        def was_scheduled(self):            assert not self.was_reserved()            assert not self.was_rate_limited()            return self.consumer.timer.call_at.called        def event_sent(self):            return self.consumer.event_dispatcher.send.call_args        def get_request(self):            if self.was_reserved():                return self.reserved.call_args[0][0]            if self.was_rate_limited():                return self.consumer._limit_task.call_args[0][0]            if self.was_scheduled():                return self.consumer.timer.call_at.call_args[0][0]            raise ValueError('request not handled')    @contextmanager    def _context(self, sig,                 rate_limits=True, events=True, utc=True, limit=None):        assert sig.type.Strategy        assert sig.type.Request        reserved = Mock()        consumer = Mock()        consumer.task_buckets = defaultdict(lambda: None)        if limit:            bucket = TokenBucket(rate(limit), capacity=1)            consumer.task_buckets[sig.task] = bucket        consumer.controller.state.revoked = set()        consumer.disable_rate_limits = not rate_limits        consumer.event_dispatcher.enabled = events        s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)        assert s        message = self.task_message_from_sig(            self.app, sig, utc=utc, TaskMessage=self.get_message_class(),        )        message = self.prepare_message(message)        yield self.Context(sig, s, reserved, consumer, message)    def test_when_logging_disabled(self):        with patch('celery.worker.strategy.logger') as logger:            logger.isEnabledFor.return_value = False            with self._context(self.add.s(2, 2)) as C:                C()                logger.info.assert_not_called()    def test_task_strategy(self):        with self._context(self.add.s(2, 2)) as C:            C()            assert C.was_reserved()            req = C.get_request()            C.consumer.on_task_request.assert_called_with(req)            assert C.event_sent()    def test_callbacks(self):        with self._context(self.add.s(2, 2)) as C:            callbacks = [Mock(name='cb1'), Mock(name='cb2')]            C(callbacks=callbacks)            req = C.get_request()            for callback in callbacks:                callback.assert_called_with(req)    def test_when_events_disabled(self):        with self._context(self.add.s(2, 2), events=False) as C:            C()            assert C.was_reserved()            assert not C.event_sent()    def test_eta_task(self):        with self._context(self.add.s(2, 2).set(countdown=10)) as C:            C()            assert C.was_scheduled()            C.consumer.qos.increment_eventually.assert_called_with()    def test_eta_task_utc_disabled(self):        with self._context(self.add.s(2, 2).set(countdown=10), utc=False) as C:            C()            assert C.was_scheduled()            C.consumer.qos.increment_eventually.assert_called_with()    def test_when_rate_limited(self):        task = self.add.s(2, 2)        with self._context(task, rate_limits=True, limit='1/m') as C:            C()            assert C.was_rate_limited()    def test_when_rate_limited_with_eta(self):        task = self.add.s(2, 2).set(countdown=10)        with self._context(task, rate_limits=True, limit='1/m') as C:            C()            assert C.was_limited_with_eta()            C.consumer.qos.increment_eventually.assert_called_with()    def test_when_rate_limited__limits_disabled(self):        task = self.add.s(2, 2)        with self._context(task, rate_limits=False, limit='1/m') as C:            C()            assert C.was_reserved()    def test_when_revoked(self):        task = self.add.s(2, 2)        task.freeze()        try:            with self._context(task) as C:                C.consumer.controller.state.revoked.add(task.id)                state.revoked.add(task.id)                C()                with pytest.raises(ValueError):                    C.get_request()        finally:            state.revoked.discard(task.id)class test_default_strategy_proto1(test_default_strategy_proto2):    def get_message_class(self):        return self.TaskMessage1class test_default_strategy_proto1__no_utc(test_default_strategy_proto2):    def get_message_class(self):        return self.TaskMessage1    def prepare_message(self, message):        message.payload['utc'] = False        return messageclass test_custom_request_for_default_strategy(test_default_strategy_proto2):    def test_custom_request_gets_instantiated(self):        _MyRequest = Mock(name='MyRequest')        class MyRequest(Request):            def __init__(self, *args, **kwargs):                Request.__init__(self, *args, **kwargs)                _MyRequest()        class MyTask(Task):            Request = MyRequest        @self.app.task(base=MyTask)        def failed():            raise AssertionError        sig = failed.s()        with self._context(sig) as C:            task_message_handler = default_strategy(                failed,                self.app,                C.consumer            )            task_message_handler(C.message, None, None, None, None)            _MyRequest.assert_called()
 |