Browse Source

100% coverage for celery.app.amqp

Ask Solem 14 years ago
parent
commit
b3763e8c6c
2 changed files with 93 additions and 4 deletions
  1. 2 1
      celery/app/amqp.py
  2. 91 3
      celery/tests/test_app/test_app_amqp.py

+ 2 - 1
celery/app/amqp.py

@@ -225,8 +225,9 @@ class TaskPublisher(messaging.Publisher):
         if chord:
             body["chord"] = chord
 
+        do_retry = retry if retry is not None else self.retry
         send = self.send
-        if retry is None and self.retry or retry:
+        if do_retry:
             send = connection.ensure(self, self.send, **_retry_policy)
         send(body, exchange=exchange, **extract_msg_options(kwargs))
         signals.task_sent.send(sender=task_name, **body)

+ 91 - 3
celery/tests/test_app/test_app_amqp.py

@@ -1,9 +1,13 @@
-from celery.tests.utils import unittest
+from __future__ import with_statement
 
-from celery.app.amqp import MSG_OPTIONS, extract_msg_options
+from mock import Mock
 
+from celery.tests.utils import AppCase
 
-class TestMsgOptions(unittest.TestCase):
+from celery.app.amqp import MSG_OPTIONS, extract_msg_options, TaskPublisher
+
+
+class TestMsgOptions(AppCase):
 
     def test_MSG_OPTIONS(self):
         self.assertTrue(MSG_OPTIONS)
@@ -13,3 +17,87 @@ class TestMsgOptions(unittest.TestCase):
         result = extract_msg_options(testing)
         self.assertEqual(result["mandatory"], True)
         self.assertEqual(result["routing_key"], "foo.xuzzy")
+
+
+class test_TaskPublisher(AppCase):
+
+    def test__exit__(self):
+
+        publisher = self.app.amqp.TaskPublisher(self.app.broker_connection())
+        publisher.close = Mock()
+        with publisher:
+            pass
+        publisher.close.assert_called_with()
+
+    def test_ensure_declare_queue(self, q="x1242112"):
+        publisher = self.app.amqp.TaskPublisher(Mock())
+        self.app.amqp.queues.add(q, q, q)
+        publisher._declare_queue(q, retry=True)
+        self.assertTrue(publisher.connection.ensure.call_count)
+
+    def test_ensure_declare_exchange(self, e="x9248311"):
+        publisher = self.app.amqp.TaskPublisher(Mock())
+        publisher._declare_exchange(e, "direct", retry=True)
+        self.assertTrue(publisher.connection.ensure.call_count)
+
+    def test_retry_policy(self):
+        pub = self.app.amqp.TaskPublisher(Mock())
+        pub.delay_task("tasks.add", (2, 2), {},
+                       retry_policy={"frobulate": 32.4})
+
+    def test_publish_no_retry(self):
+        pub = self.app.amqp.TaskPublisher(Mock())
+        pub.delay_task("tasks.add", (2, 2), {}, retry=False, chord=123)
+        self.assertFalse(pub.connection.ensure.call_count)
+
+
+class test_PublisherPool(AppCase):
+
+    def test_setup_nolimit(self):
+        L = self.app.conf.BROKER_POOL_LIMIT
+        self.app.conf.BROKER_POOL_LIMIT = None
+        try:
+            delattr(self.app, "_pool")
+        except AttributeError:
+            pass
+        self.app.amqp.__dict__.pop("publisher_pool", None)
+        try:
+            pool = self.app.amqp.publisher_pool
+            self.assertEqual(pool.limit, self.app.pool.limit)
+            self.assertFalse(pool._resource.queue)
+
+            r1 = pool.acquire()
+            r2 = pool.acquire()
+            r1.release()
+            r2.release()
+            r1 = pool.acquire()
+            r2 = pool.acquire()
+        finally:
+            self.app.conf.BROKER_POOL_LIMIT = L
+
+    def test_setup(self):
+        L = self.app.conf.BROKER_POOL_LIMIT
+        self.app.conf.BROKER_POOL_LIMIT = 2
+        try:
+            delattr(self.app, "_pool")
+        except AttributeError:
+            pass
+        self.app.amqp.__dict__.pop("publisher_pool", None)
+        try:
+            pool = self.app.amqp.publisher_pool
+            self.assertEqual(pool.limit, self.app.pool.limit)
+            self.assertTrue(pool._resource.queue)
+
+            p1 = r1 = pool.acquire()
+            p2 = r2 = pool.acquire()
+            delattr(r1.connection, "_publisher_chan")
+            r1.release()
+            r2.release()
+            r1 = pool.acquire()
+            r2 = pool.acquire()
+            self.assertIs(p2, r1)
+            self.assertIs(p1, r2)
+            r1.release()
+            r2.release()
+        finally:
+            self.app.conf.BROKER_POOL_LIMIT = L