Browse Source

Use kombu.common.maybe_declare

Ask Solem 13 years ago
parent
commit
994a8ae378
3 changed files with 35 additions and 35 deletions
  1. 28 32
      celery/app/amqp.py
  2. 5 3
      celery/tests/test_app/__init__.py
  3. 2 0
      celery/tests/test_app/test_app_amqp.py

+ 28 - 32
celery/app/amqp.py

@@ -16,6 +16,7 @@ from datetime import timedelta
 from kombu import BrokerConnection, Exchange
 from kombu import compat as messaging
 from kombu import pools
+from kombu.common import maybe_declare
 
 from celery import signals
 from celery.utils import cached_property, lpmerge, uuid
@@ -34,13 +35,6 @@ QUEUE_FORMAT = """
 binding:%(binding_key)s
 """
 
-#: Set of exchange names that have already been declared.
-_exchanges_declared = set()
-
-#: Set of queue names that have already been declared.
-_queues_declared = set()
-
-
 def extract_msg_options(options, keep=MSG_OPTIONS):
     """Extracts known options to `basic_publish` from a dict,
     and returns a new dict."""
@@ -151,6 +145,8 @@ class TaskPublisher(messaging.Publisher):
     auto_declare = False
     retry = False
     retry_policy = None
+    _queue_cache = {}
+    _exchange_cache = {}
 
     def __init__(self, *args, **kwargs):
         self.app = kwargs.pop("app")
@@ -161,26 +157,31 @@ class TaskPublisher(messaging.Publisher):
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
     def declare(self):
-        if self.exchange.name and \
-                self.exchange.name not in _exchanges_declared:
+        if self.exchange.name and not declaration_cached(self.exchange):
             super(TaskPublisher, self).declare()
-            _exchanges_declared.add(self.exchange.name)
+
+    def _get_queue(self, name):
+        if name not in self._queue_cache:
+            options = self.app.amqp.queues[name]
+            self._queue_cache[name] = messaging.entry_to_queue(name, **options)
+        return self._queue_cache[name]
+
+    def _get_exchange(self, name, type=None):
+        if name not in self._exchange_cache:
+            self._exchange_cache[name] = Exchange(name,
+                type=type or self.exchange_type,
+                durable=self.durable,
+                auto_delete=self.auto_delete,
+            )
+        return self._exchange_cache[name]
 
     def _declare_queue(self, name, retry=False, retry_policy={}):
-        options = self.app.amqp.queues[name]
-        queue = messaging.entry_to_queue(name, **options)(self.channel)
-        if retry:
-            self.connection.ensure(queue, queue.declare, **retry_policy)()
-        else:
-            queue.declare()
-        return queue
-
-    def _declare_exchange(self, name, type, retry=False, retry_policy={}):
-        ex = Exchange(name, type=type, durable=self.durable,
-                      auto_delete=self.auto_delete)(self.channel)
-        if retry:
-            return self.connection.ensure(ex, ex.declare, **retry_policy)
-        return ex.declare()
+        maybe_declare(self._get_queue(name), self.channel,
+                      retry=retry, **retry_policy)
+
+    def _declare_exchange(self, name, type=None, retry=False, retry_policy={}):
+        maybe_declare(self._get_exchange(name, type), self.channel,
+                      retry=retry, **retry_policy)
 
     def delay_task(self, task_name, task_args=None, task_kwargs=None,
             countdown=None, eta=None, task_id=None, taskset_id=None,
@@ -196,14 +197,9 @@ class TaskPublisher(messaging.Publisher):
             _retry_policy = dict(_retry_policy, **retry_policy)
 
         # declare entities
-        if queue and queue not in _queues_declared:
-            entity = self._declare_queue(queue, retry, _retry_policy)
-            _exchanges_declared.add(entity.exchange.name)
-            _queues_declared.add(entity.name)
-        if exchange and exchange not in _exchanges_declared:
-            self._declare_exchange(exchange,
-                    exchange_type or self.exchange_type, retry, _retry_policy)
-            _exchanges_declared.add(exchange)
+        if queue:
+            self._declare_queue(queue, retry, _retry_policy)
+        self._declare_exchange(exchange, exchange_type, retry, _retry_policy)
 
         task_id = task_id or uuid()
         task_args = task_args or []

+ 5 - 3
celery/tests/test_app/__init__.py

@@ -227,22 +227,24 @@ class test_App(Case):
             chan.close()
         assert conn.transport_cls == "memory"
 
+        entities = conn.declared_entities
+
         pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
-        self.assertNotIn("foo_exchange", amqp._exchanges_declared)
+        self.assertNotIn(pub._get_exchange("foo_exchange"), entities)
 
         dispatcher = Dispatcher()
         self.assertTrue(pub.delay_task("footask", (), {},
                                        exchange="moo_exchange",
                                        routing_key="moo_exchange",
                                        event_dispatcher=dispatcher))
-        self.assertIn("moo_exchange", amqp._exchanges_declared)
+        self.assertIn(pub._get_exchange("moo_exchange"), entities)
         self.assertTrue(dispatcher.sent)
         self.assertEqual(dispatcher.sent[0][0], "task-sent")
         self.assertTrue(pub.delay_task("footask", (), {},
                                        event_dispatcher=dispatcher,
                                        exchange="bar_exchange",
                                        routing_key="bar_exchange"))
-        self.assertIn("bar_exchange", amqp._exchanges_declared)
+        self.assertIn(pub._get_exchange("bar_exchange"), entities)
 
     def test_error_mail_sender(self):
         x = ErrorMail.subject % {"name": "task_name",

+ 2 - 0
celery/tests/test_app/test_app_amqp.py

@@ -42,11 +42,13 @@ class test_TaskPublisher(AppCase):
 
     def test_retry_policy(self):
         pub = self.app.amqp.TaskPublisher(Mock())
+        pub.channel.connection.client.declared_entities = set()
         pub.delay_task("tasks.add", (2, 2), {},
                        retry_policy={"frobulate": 32.4})
 
     def test_publish_no_retry(self):
         pub = self.app.amqp.TaskPublisher(Mock())
+        pub.channel.connection.client.declared_entities = set()
         pub.delay_task("tasks.add", (2, 2), {}, retry=False, chord=123)
         self.assertFalse(pub.connection.ensure.call_count)