|
@@ -14,7 +14,7 @@ from celery.datastructures import ExceptionInfo
|
|
from celery.exceptions import TimeoutError
|
|
from celery.exceptions import TimeoutError
|
|
from celery.utils import uuid
|
|
from celery.utils import uuid
|
|
|
|
|
|
-from celery.tests.utils import Case, sleepdeprived
|
|
|
|
|
|
+from celery.tests.utils import AppCase, sleepdeprived
|
|
|
|
|
|
|
|
|
|
class SomeClass(object):
|
|
class SomeClass(object):
|
|
@@ -23,7 +23,7 @@ class SomeClass(object):
|
|
self.data = data
|
|
self.data = data
|
|
|
|
|
|
|
|
|
|
-class test_AMQPBackend(Case):
|
|
|
|
|
|
+class test_AMQPBackend(AppCase):
|
|
|
|
|
|
def create_backend(self, **opts):
|
|
def create_backend(self, **opts):
|
|
opts = dict(dict(serializer="pickle", persistent=False), **opts)
|
|
opts = dict(dict(serializer="pickle", persistent=False), **opts)
|
|
@@ -101,35 +101,35 @@ class test_AMQPBackend(Case):
|
|
|
|
|
|
@sleepdeprived()
|
|
@sleepdeprived()
|
|
def test_store_result_retries(self):
|
|
def test_store_result_retries(self):
|
|
|
|
+ iterations = [0]
|
|
|
|
+ stop_raising_at = [5]
|
|
|
|
|
|
- class _Producer(object):
|
|
|
|
- iterations = 0
|
|
|
|
- stop_raising_at = 5
|
|
|
|
-
|
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
|
- pass
|
|
|
|
-
|
|
|
|
- def publish(self, msg, *args, **kwargs):
|
|
|
|
- if self.iterations > self.stop_raising_at:
|
|
|
|
- return
|
|
|
|
- raise KeyError("foo")
|
|
|
|
-
|
|
|
|
- class Backend(AMQPBackend):
|
|
|
|
- Producer = _Producer
|
|
|
|
|
|
+ def publish(*args, **kwargs):
|
|
|
|
+ if iterations[0] > stop_raising_at[0]:
|
|
|
|
+ return
|
|
|
|
+ iterations[0] += 1
|
|
|
|
+ raise KeyError("foo")
|
|
|
|
|
|
- backend = Backend()
|
|
|
|
- with self.assertRaises(KeyError):
|
|
|
|
- backend.store_result("foo", "bar", "STARTED", max_retries=None)
|
|
|
|
|
|
+ backend = AMQPBackend()
|
|
|
|
+ from celery.app.amqp import TaskProducer
|
|
|
|
+ prod, TaskProducer.publish = TaskProducer.publish, publish
|
|
|
|
+ try:
|
|
|
|
+ with self.assertRaises(KeyError):
|
|
|
|
+ backend.retry_policy["max_retries"] = None
|
|
|
|
+ backend.store_result("foo", "bar", "STARTED")
|
|
|
|
|
|
- with self.assertRaises(KeyError):
|
|
|
|
- backend.store_result("foo", "bar", "STARTED", max_retries=10)
|
|
|
|
|
|
+ with self.assertRaises(KeyError):
|
|
|
|
+ backend.retry_policy["max_retries"] = 10
|
|
|
|
+ backend.store_result("foo", "bar", "STARTED")
|
|
|
|
+ finally:
|
|
|
|
+ TaskProducer.publish = prod
|
|
|
|
|
|
def assertState(self, retval, state):
|
|
def assertState(self, retval, state):
|
|
self.assertEqual(retval["status"], state)
|
|
self.assertEqual(retval["status"], state)
|
|
|
|
|
|
def test_poll_no_messages(self):
|
|
def test_poll_no_messages(self):
|
|
b = self.create_backend()
|
|
b = self.create_backend()
|
|
- self.assertState(b.poll(uuid()), states.PENDING)
|
|
|
|
|
|
+ self.assertState(b.get_task_meta(uuid()), states.PENDING)
|
|
|
|
|
|
def test_poll_result(self):
|
|
def test_poll_result(self):
|
|
|
|
|
|
@@ -167,7 +167,7 @@ class test_AMQPBackend(Case):
|
|
results.put(Message(status=states.RECEIVED, seq=1))
|
|
results.put(Message(status=states.RECEIVED, seq=1))
|
|
results.put(Message(status=states.STARTED, seq=2))
|
|
results.put(Message(status=states.STARTED, seq=2))
|
|
results.put(Message(status=states.FAILURE, seq=3))
|
|
results.put(Message(status=states.FAILURE, seq=3))
|
|
- r1 = backend.poll(uuid())
|
|
|
|
|
|
+ r1 = backend.get_task_meta(uuid())
|
|
self.assertDictContainsSubset({"status": states.FAILURE,
|
|
self.assertDictContainsSubset({"status": states.FAILURE,
|
|
"seq": 3}, r1,
|
|
"seq": 3}, r1,
|
|
"FFWDs to the last state")
|
|
"FFWDs to the last state")
|
|
@@ -175,14 +175,14 @@ class test_AMQPBackend(Case):
|
|
# Caches last known state.
|
|
# Caches last known state.
|
|
results.put(Message())
|
|
results.put(Message())
|
|
tid = uuid()
|
|
tid = uuid()
|
|
- backend.poll(tid)
|
|
|
|
|
|
+ backend.get_task_meta(tid)
|
|
self.assertIn(tid, backend._cache, "Caches last known state")
|
|
self.assertIn(tid, backend._cache, "Caches last known state")
|
|
|
|
|
|
# Returns cache if no new states.
|
|
# Returns cache if no new states.
|
|
results.queue.clear()
|
|
results.queue.clear()
|
|
assert not results.qsize()
|
|
assert not results.qsize()
|
|
backend._cache[tid] = "hello"
|
|
backend._cache[tid] = "hello"
|
|
- self.assertEqual(backend.poll(tid), "hello",
|
|
|
|
|
|
+ self.assertEqual(backend.get_task_meta(tid), "hello",
|
|
"Returns cache if no new states")
|
|
"Returns cache if no new states")
|
|
|
|
|
|
def test_wait_for(self):
|
|
def test_wait_for(self):
|
|
@@ -217,7 +217,7 @@ class test_AMQPBackend(Case):
|
|
b = self.create_backend()
|
|
b = self.create_backend()
|
|
with current_app.pool.acquire_channel(block=False) as (_, channel):
|
|
with current_app.pool.acquire_channel(block=False) as (_, channel):
|
|
binding = b._create_binding(uuid())
|
|
binding = b._create_binding(uuid())
|
|
- consumer = b._create_consumer(binding, channel)
|
|
|
|
|
|
+ consumer = b.Consumer(channel, binding, no_ack=True)
|
|
with self.assertRaises(socket.timeout):
|
|
with self.assertRaises(socket.timeout):
|
|
b.drain_events(Connection(), consumer, timeout=0.1)
|
|
b.drain_events(Connection(), consumer, timeout=0.1)
|
|
|
|
|
|
@@ -249,7 +249,7 @@ class test_AMQPBackend(Case):
|
|
|
|
|
|
class Backend(AMQPBackend):
|
|
class Backend(AMQPBackend):
|
|
|
|
|
|
- def _create_consumer(self, *args, **kwargs):
|
|
|
|
|
|
+ def Consumer(*args, **kwargs):
|
|
raise KeyError("foo")
|
|
raise KeyError("foo")
|
|
|
|
|
|
b = Backend()
|
|
b = Backend()
|