|
@@ -9,6 +9,7 @@ from kombu.connection import BrokerConnection
|
|
|
from celery.utils.timer2 import Timer
|
|
|
|
|
|
from celery.app import app_or_default
|
|
|
+from celery.concurrency.base import BasePool
|
|
|
from celery.decorators import task as task_dec
|
|
|
from celery.decorators import periodic_task as periodic_task_dec
|
|
|
from celery.serialization import pickle
|
|
@@ -117,7 +118,7 @@ class MockBackend(object):
|
|
|
self._acked = True
|
|
|
|
|
|
|
|
|
-class MockPool(object):
|
|
|
+class MockPool(BasePool):
|
|
|
_terminated = False
|
|
|
_stopped = False
|
|
|
|
|
@@ -436,13 +437,16 @@ class test_Consumer(unittest.TestCase):
|
|
|
l.broadcast_consumer = MockConsumer()
|
|
|
l.qos = _QoS()
|
|
|
l.connection = BrokerConnection()
|
|
|
+ l.iterations = 0
|
|
|
|
|
|
def raises_KeyError(limit=None):
|
|
|
- yield True
|
|
|
- l.iterations = 1
|
|
|
- raise KeyError("foo")
|
|
|
+ l.iterations += 1
|
|
|
+ if l.qos.prev != l.qos.next:
|
|
|
+ l.qos.update()
|
|
|
+ if l.iterations >= 2:
|
|
|
+ raise KeyError("foo")
|
|
|
|
|
|
- l._mainloop = raises_KeyError
|
|
|
+ l.consume_messages = raises_KeyError
|
|
|
self.assertRaises(KeyError, l.start)
|
|
|
self.assertTrue(called_back[0])
|
|
|
self.assertEqual(l.iterations, 1)
|
|
@@ -456,11 +460,10 @@ class test_Consumer(unittest.TestCase):
|
|
|
l.connection = BrokerConnection()
|
|
|
|
|
|
def raises_socket_error(limit=None):
|
|
|
- yield True
|
|
|
l.iterations = 1
|
|
|
raise socket.error("foo")
|
|
|
|
|
|
- l._mainloop = raises_socket_error
|
|
|
+ l.consume_messages = raises_socket_error
|
|
|
self.assertRaises(socket.error, l.start)
|
|
|
self.assertTrue(called_back[0])
|
|
|
self.assertEqual(l.iterations, 1)
|
|
@@ -509,8 +512,11 @@ class test_WorkController(unittest.TestCase):
|
|
|
m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
|
|
|
kwargs={})
|
|
|
task = TaskRequest.from_message(m, m.decode())
|
|
|
- worker.process_task(task)
|
|
|
- worker.pool.stop()
|
|
|
+ worker.components = []
|
|
|
+ worker._state = worker.RUN
|
|
|
+ self.assertRaises(KeyboardInterrupt, worker.process_task, task)
|
|
|
+ self.assertEqual(worker._state, worker.TERMINATE)
|
|
|
+
|
|
|
|
|
|
def test_process_task_raise_regular(self):
|
|
|
worker = self.worker
|