| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 | 
							- from __future__ import absolute_import, unicode_literals
 
- from collections import defaultdict
 
- from contextlib import contextmanager
 
- import pytest
 
- from case import Mock, patch
 
- from kombu.utils.limits import TokenBucket
 
- from celery import Task
 
- from celery.exceptions import InvalidTaskError
 
- from celery.utils.time import rate
 
- from celery.worker import state
 
- from celery.worker.request import Request
 
- from celery.worker.strategy import default as default_strategy
 
- from celery.worker.strategy import proto1_to_proto2
 
- class test_proto1_to_proto2:
 
-     def setup(self):
 
-         self.message = Mock(name='message')
 
-         self.body = {
 
-             'args': (1,),
 
-             'kwargs': {'foo': 'baz'},
 
-             'utc': False,
 
-             'taskset': '123',
 
-         }
 
-     def test_message_without_args(self):
 
-         self.body.pop('args')
 
-         body, _, _, _ = proto1_to_proto2(self.message, self.body)
 
-         assert body[:2] == ((), {'foo': 'baz'})
 
-     def test_message_without_kwargs(self):
 
-         self.body.pop('kwargs')
 
-         body, _, _, _ = proto1_to_proto2(self.message, self.body)
 
-         assert body[:2] == ((1,), {})
 
-     def test_message_kwargs_not_mapping(self):
 
-         self.body['kwargs'] = (2,)
 
-         with pytest.raises(InvalidTaskError):
 
-             proto1_to_proto2(self.message, self.body)
 
-     def test_message_no_taskset_id(self):
 
-         self.body.pop('taskset')
 
-         assert proto1_to_proto2(self.message, self.body)
 
-     def test_message(self):
 
-         body, headers, decoded, utc = proto1_to_proto2(self.message, self.body)
 
-         assert body == ((1,), {'foo': 'baz'}, {
 
-             'callbacks': None, 'errbacks': None, 'chord': None, 'chain': None,
 
-         })
 
-         assert headers == dict(self.body, group='123')
 
-         assert decoded
 
-         assert not utc
 
- class test_default_strategy_proto2:
 
-     def setup(self):
 
-         @self.app.task(shared=False)
 
-         def add(x, y):
 
-             return x + y
 
-         self.add = add
 
-     def get_message_class(self):
 
-         return self.TaskMessage
 
-     def prepare_message(self, message):
 
-         return message
 
-     class Context(object):
 
-         def __init__(self, sig, s, reserved, consumer, message):
 
-             self.sig = sig
 
-             self.s = s
 
-             self.reserved = reserved
 
-             self.consumer = consumer
 
-             self.message = message
 
-         def __call__(self, callbacks=[], **kwargs):
 
-             return self.s(
 
-                 self.message,
 
-                 (self.message.payload
 
-                     if not self.message.headers.get('id') else None),
 
-                 self.message.ack, self.message.reject, callbacks, **kwargs
 
-             )
 
-         def was_reserved(self):
 
-             return self.reserved.called
 
-         def was_rate_limited(self):
 
-             assert not self.was_reserved()
 
-             return self.consumer._limit_task.called
 
-         def was_limited_with_eta(self):
 
-             assert not self.was_reserved()
 
-             called = self.consumer.timer.call_at.called
 
-             if called:
 
-                 assert self.consumer.timer.call_at.call_args[0][1] == \
 
-                     self.consumer._limit_post_eta
 
-             return called
 
-         def was_scheduled(self):
 
-             assert not self.was_reserved()
 
-             assert not self.was_rate_limited()
 
-             return self.consumer.timer.call_at.called
 
-         def event_sent(self):
 
-             return self.consumer.event_dispatcher.send.call_args
 
-         def get_request(self):
 
-             if self.was_reserved():
 
-                 return self.reserved.call_args[0][0]
 
-             if self.was_rate_limited():
 
-                 return self.consumer._limit_task.call_args[0][0]
 
-             if self.was_scheduled():
 
-                 return self.consumer.timer.call_at.call_args[0][0]
 
-             raise ValueError('request not handled')
 
-     @contextmanager
 
-     def _context(self, sig,
 
-                  rate_limits=True, events=True, utc=True, limit=None):
 
-         assert sig.type.Strategy
 
-         assert sig.type.Request
 
-         reserved = Mock()
 
-         consumer = Mock()
 
-         consumer.task_buckets = defaultdict(lambda: None)
 
-         if limit:
 
-             bucket = TokenBucket(rate(limit), capacity=1)
 
-             consumer.task_buckets[sig.task] = bucket
 
-         consumer.controller.state.revoked = set()
 
-         consumer.disable_rate_limits = not rate_limits
 
-         consumer.event_dispatcher.enabled = events
 
-         s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)
 
-         assert s
 
-         message = self.task_message_from_sig(
 
-             self.app, sig, utc=utc, TaskMessage=self.get_message_class(),
 
-         )
 
-         message = self.prepare_message(message)
 
-         yield self.Context(sig, s, reserved, consumer, message)
 
-     def test_when_logging_disabled(self):
 
-         with patch('celery.worker.strategy.logger') as logger:
 
-             logger.isEnabledFor.return_value = False
 
-             with self._context(self.add.s(2, 2)) as C:
 
-                 C()
 
-                 logger.info.assert_not_called()
 
-     def test_task_strategy(self):
 
-         with self._context(self.add.s(2, 2)) as C:
 
-             C()
 
-             assert C.was_reserved()
 
-             req = C.get_request()
 
-             C.consumer.on_task_request.assert_called_with(req)
 
-             assert C.event_sent()
 
-     def test_callbacks(self):
 
-         with self._context(self.add.s(2, 2)) as C:
 
-             callbacks = [Mock(name='cb1'), Mock(name='cb2')]
 
-             C(callbacks=callbacks)
 
-             req = C.get_request()
 
-             for callback in callbacks:
 
-                 callback.assert_called_with(req)
 
-     def test_when_events_disabled(self):
 
-         with self._context(self.add.s(2, 2), events=False) as C:
 
-             C()
 
-             assert C.was_reserved()
 
-             assert not C.event_sent()
 
-     def test_eta_task(self):
 
-         with self._context(self.add.s(2, 2).set(countdown=10)) as C:
 
-             C()
 
-             assert C.was_scheduled()
 
-             C.consumer.qos.increment_eventually.assert_called_with()
 
-     def test_eta_task_utc_disabled(self):
 
-         with self._context(self.add.s(2, 2).set(countdown=10), utc=False) as C:
 
-             C()
 
-             assert C.was_scheduled()
 
-             C.consumer.qos.increment_eventually.assert_called_with()
 
-     def test_when_rate_limited(self):
 
-         task = self.add.s(2, 2)
 
-         with self._context(task, rate_limits=True, limit='1/m') as C:
 
-             C()
 
-             assert C.was_rate_limited()
 
-     def test_when_rate_limited_with_eta(self):
 
-         task = self.add.s(2, 2).set(countdown=10)
 
-         with self._context(task, rate_limits=True, limit='1/m') as C:
 
-             C()
 
-             assert C.was_limited_with_eta()
 
-             C.consumer.qos.increment_eventually.assert_called_with()
 
-     def test_when_rate_limited__limits_disabled(self):
 
-         task = self.add.s(2, 2)
 
-         with self._context(task, rate_limits=False, limit='1/m') as C:
 
-             C()
 
-             assert C.was_reserved()
 
-     def test_when_revoked(self):
 
-         task = self.add.s(2, 2)
 
-         task.freeze()
 
-         try:
 
-             with self._context(task) as C:
 
-                 C.consumer.controller.state.revoked.add(task.id)
 
-                 state.revoked.add(task.id)
 
-                 C()
 
-                 with pytest.raises(ValueError):
 
-                     C.get_request()
 
-         finally:
 
-             state.revoked.discard(task.id)
 
- class test_default_strategy_proto1(test_default_strategy_proto2):
 
-     def get_message_class(self):
 
-         return self.TaskMessage1
 
- class test_default_strategy_proto1__no_utc(test_default_strategy_proto2):
 
-     def get_message_class(self):
 
-         return self.TaskMessage1
 
-     def prepare_message(self, message):
 
-         message.payload['utc'] = False
 
-         return message
 
- class test_custom_request_for_default_strategy(test_default_strategy_proto2):
 
-     def test_custom_request_gets_instantiated(self):
 
-         _MyRequest = Mock(name='MyRequest')
 
-         class MyRequest(Request):
 
-             def __init__(self, *args, **kwargs):
 
-                 Request.__init__(self, *args, **kwargs)
 
-                 _MyRequest()
 
-         class MyTask(Task):
 
-             Request = MyRequest
 
-         @self.app.task(base=MyTask)
 
-         def failed():
 
-             raise AssertionError
 
-         sig = failed.s()
 
-         with self._context(sig) as C:
 
-             task_message_handler = default_strategy(
 
-                 failed,
 
-                 self.app,
 
-                 C.consumer
 
-             )
 
-             task_message_handler(C.message, None, None, None, None)
 
-             _MyRequest.assert_called()
 
 
  |