Browse Source

Tests passes again

Ask Solem 14 years ago
parent
commit
5ca6bee22d
1 changed files with 11 additions and 10 deletions
  1. 11 10
      celery/tests/test_worker/test_worker.py

+ 11 - 10
celery/tests/test_worker/test_worker.py

@@ -8,7 +8,7 @@ from kombu.transport.base import Message
 from kombu.connection import BrokerConnection
 from celery.utils.timer2 import Timer
 
-from celery.app import app_or_default
+from celery import current_app
 from celery.concurrency.base import BasePool
 from celery.exceptions import SystemTerminate
 from celery.task import task as task_dec
@@ -52,6 +52,10 @@ class MyKombuConsumer(MainConsumer):
     broadcast_consumer = MockConsumer()
     task_consumer = MockConsumer()
 
+    def __init__(self, *args, **kwargs):
+        kwargs.setdefault("pool", BasePool(2))
+        super(MyKombuConsumer, self).__init__(*args, **kwargs)
+
     def restart_heartbeat(self):
         self.heart = None
 
@@ -234,7 +238,7 @@ class test_QoS(unittest.TestCase):
 
     def test_consumer_increment_decrement(self):
         consumer = self.MockConsumer()
-        qos = QoS(consumer, 10, app_or_default().log.get_default_logger())
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
         qos.update()
         self.assertEqual(qos.value, 10)
         self.assertEqual(consumer.prefetch_count, 10)
@@ -258,7 +262,7 @@ class test_Consumer(unittest.TestCase):
     def setUp(self):
         self.ready_queue = FastQueue()
         self.eta_schedule = Timer()
-        self.logger = app_or_default().log.get_default_logger()
+        self.logger = current_app.log.get_default_logger()
         self.logger.setLevel(0)
 
     def tearDown(self):
@@ -272,7 +276,7 @@ class test_Consumer(unittest.TestCase):
         self.assertEqual(info["prefetch_count"], 10)
         self.assertFalse(info["broker"])
 
-        l.connection = app_or_default().broker_connection()
+        l.connection = current_app.broker_connection()
         info = l.info
         self.assertTrue(info["broker"])
 
@@ -419,15 +423,14 @@ class test_Consumer(unittest.TestCase):
                 raise SyntaxError("bar")
 
         l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False)
+                             send_events=False, pool=BasePool())
         l.connection_errors = (KeyError, )
         self.assertRaises(SyntaxError, l.start)
         l.heart.stop()
 
     def test_consume_messages(self):
-        app = app_or_default()
 
-        class Connection(app.broker_connection().__class__):
+        class Connection(current_app.broker_connection().__class__):
             obj = None
 
             def drain_events(self, **kwargs):
@@ -449,13 +452,11 @@ class test_Consumer(unittest.TestCase):
         l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Consumer()
-        l.broadcast_consumer = Consumer()
         l.qos = QoS(l.task_consumer, 10, l.logger)
 
         l.consume_messages()
         l.consume_messages()
         self.assertTrue(l.task_consumer.consuming)
-        self.assertTrue(l.broadcast_consumer.consuming)
         self.assertEqual(l.task_consumer.prefetch_count, 10)
 
         l.qos.decrement()
@@ -660,7 +661,7 @@ class test_WorkController(AppCase):
         ignored_signals = []
         reset_signals = []
         worker_init = [False]
-        default_app = app_or_default()
+        default_app = current_app
         app = Celery(loader="default", set_as_current=False)
 
         class Loader(object):