|
@@ -8,9 +8,14 @@ from contextlib import contextmanager
|
|
from case import Mock, patch
|
|
from case import Mock, patch
|
|
from kombu.utils.limits import TokenBucket
|
|
from kombu.utils.limits import TokenBucket
|
|
|
|
|
|
|
|
+from celery import Task
|
|
from celery.exceptions import InvalidTaskError
|
|
from celery.exceptions import InvalidTaskError
|
|
from celery.worker import state
|
|
from celery.worker import state
|
|
-from celery.worker.strategy import proto1_to_proto2
|
|
|
|
|
|
+from celery.worker.strategy import (
|
|
|
|
+ proto1_to_proto2,
|
|
|
|
+ default as default_strategy
|
|
|
|
+)
|
|
|
|
+from celery.worker.request import Request
|
|
from celery.utils.time import rate
|
|
from celery.utils.time import rate
|
|
|
|
|
|
|
|
|
|
@@ -114,6 +119,7 @@ class test_default_strategy_proto2:
|
|
def _context(self, sig,
|
|
def _context(self, sig,
|
|
rate_limits=True, events=True, utc=True, limit=None):
|
|
rate_limits=True, events=True, utc=True, limit=None):
|
|
assert sig.type.Strategy
|
|
assert sig.type.Strategy
|
|
|
|
+ assert sig.type.Request
|
|
|
|
|
|
reserved = Mock()
|
|
reserved = Mock()
|
|
consumer = Mock()
|
|
consumer = Mock()
|
|
@@ -214,3 +220,30 @@ class test_default_strategy_proto1__no_utc(test_default_strategy_proto2):
|
|
def prepare_message(self, message):
|
|
def prepare_message(self, message):
|
|
message.payload['utc'] = False
|
|
message.payload['utc'] = False
|
|
return message
|
|
return message
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class test_custom_request_for_default_strategy(test_default_strategy_proto2):
|
|
|
|
+ def test_custom_request_gets_instantiated(self):
|
|
|
|
+ _MyRequest = Mock(name='MyRequest')
|
|
|
|
+
|
|
|
|
+ class MyRequest(Request):
|
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
|
+ Request.__init__(self, *args, **kwargs)
|
|
|
|
+ _MyRequest()
|
|
|
|
+
|
|
|
|
+ class MyTask(Task):
|
|
|
|
+ Request = MyRequest
|
|
|
|
+
|
|
|
|
+ @self.app.task(base=MyTask)
|
|
|
|
+ def failed():
|
|
|
|
+ raise AssertionError
|
|
|
|
+
|
|
|
|
+ sig = failed.s()
|
|
|
|
+ with self._context(sig) as C:
|
|
|
|
+ task_message_handler = default_strategy(
|
|
|
|
+ failed,
|
|
|
|
+ self.app,
|
|
|
|
+ C.consumer
|
|
|
|
+ )
|
|
|
|
+ task_message_handler(C.message, None, None, None, None)
|
|
|
|
+ _MyRequest.assert_called()
|