瀏覽代碼

Use kombu.common.maybe_declare

Ask Solem 13 年之前
父節點
當前提交
994a8ae378
共有 3 個文件被更改,包括 35 次插入35 次删除
  1. 28 32
      celery/app/amqp.py
  2. 5 3
      celery/tests/test_app/__init__.py
  3. 2 0
      celery/tests/test_app/test_app_amqp.py

+ 28 - 32
celery/app/amqp.py

@@ -16,6 +16,7 @@ from datetime import timedelta
 from kombu import BrokerConnection, Exchange
 from kombu import BrokerConnection, Exchange
 from kombu import compat as messaging
 from kombu import compat as messaging
 from kombu import pools
 from kombu import pools
+from kombu.common import maybe_declare
 
 
 from celery import signals
 from celery import signals
 from celery.utils import cached_property, lpmerge, uuid
 from celery.utils import cached_property, lpmerge, uuid
@@ -34,13 +35,6 @@ QUEUE_FORMAT = """
 binding:%(binding_key)s
 binding:%(binding_key)s
 """
 """
 
 
-#: Set of exchange names that have already been declared.
-_exchanges_declared = set()
-
-#: Set of queue names that have already been declared.
-_queues_declared = set()
-
-
 def extract_msg_options(options, keep=MSG_OPTIONS):
 def extract_msg_options(options, keep=MSG_OPTIONS):
     """Extracts known options to `basic_publish` from a dict,
     """Extracts known options to `basic_publish` from a dict,
     and returns a new dict."""
     and returns a new dict."""
@@ -151,6 +145,8 @@ class TaskPublisher(messaging.Publisher):
     auto_declare = False
     auto_declare = False
     retry = False
     retry = False
     retry_policy = None
     retry_policy = None
+    _queue_cache = {}
+    _exchange_cache = {}
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         self.app = kwargs.pop("app")
         self.app = kwargs.pop("app")
@@ -161,26 +157,31 @@ class TaskPublisher(messaging.Publisher):
         super(TaskPublisher, self).__init__(*args, **kwargs)
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
 
     def declare(self):
     def declare(self):
-        if self.exchange.name and \
-                self.exchange.name not in _exchanges_declared:
+        if self.exchange.name and not declaration_cached(self.exchange):
             super(TaskPublisher, self).declare()
             super(TaskPublisher, self).declare()
-            _exchanges_declared.add(self.exchange.name)
+
+    def _get_queue(self, name):
+        if name not in self._queue_cache:
+            options = self.app.amqp.queues[name]
+            self._queue_cache[name] = messaging.entry_to_queue(name, **options)
+        return self._queue_cache[name]
+
+    def _get_exchange(self, name, type=None):
+        if name not in self._exchange_cache:
+            self._exchange_cache[name] = Exchange(name,
+                type=type or self.exchange_type,
+                durable=self.durable,
+                auto_delete=self.auto_delete,
+            )
+        return self._exchange_cache[name]
 
 
     def _declare_queue(self, name, retry=False, retry_policy={}):
     def _declare_queue(self, name, retry=False, retry_policy={}):
-        options = self.app.amqp.queues[name]
-        queue = messaging.entry_to_queue(name, **options)(self.channel)
-        if retry:
-            self.connection.ensure(queue, queue.declare, **retry_policy)()
-        else:
-            queue.declare()
-        return queue
-
-    def _declare_exchange(self, name, type, retry=False, retry_policy={}):
-        ex = Exchange(name, type=type, durable=self.durable,
-                      auto_delete=self.auto_delete)(self.channel)
-        if retry:
-            return self.connection.ensure(ex, ex.declare, **retry_policy)
-        return ex.declare()
+        maybe_declare(self._get_queue(name), self.channel,
+                      retry=retry, **retry_policy)
+
+    def _declare_exchange(self, name, type=None, retry=False, retry_policy={}):
+        maybe_declare(self._get_exchange(name, type), self.channel,
+                      retry=retry, **retry_policy)
 
 
     def delay_task(self, task_name, task_args=None, task_kwargs=None,
     def delay_task(self, task_name, task_args=None, task_kwargs=None,
             countdown=None, eta=None, task_id=None, taskset_id=None,
             countdown=None, eta=None, task_id=None, taskset_id=None,
@@ -196,14 +197,9 @@ class TaskPublisher(messaging.Publisher):
             _retry_policy = dict(_retry_policy, **retry_policy)
             _retry_policy = dict(_retry_policy, **retry_policy)
 
 
         # declare entities
         # declare entities
-        if queue and queue not in _queues_declared:
-            entity = self._declare_queue(queue, retry, _retry_policy)
-            _exchanges_declared.add(entity.exchange.name)
-            _queues_declared.add(entity.name)
-        if exchange and exchange not in _exchanges_declared:
-            self._declare_exchange(exchange,
-                    exchange_type or self.exchange_type, retry, _retry_policy)
-            _exchanges_declared.add(exchange)
+        if queue:
+            self._declare_queue(queue, retry, _retry_policy)
+        self._declare_exchange(exchange, exchange_type, retry, _retry_policy)
 
 
         task_id = task_id or uuid()
         task_id = task_id or uuid()
         task_args = task_args or []
         task_args = task_args or []

+ 5 - 3
celery/tests/test_app/__init__.py

@@ -227,22 +227,24 @@ class test_App(Case):
             chan.close()
             chan.close()
         assert conn.transport_cls == "memory"
         assert conn.transport_cls == "memory"
 
 
+        entities = conn.declared_entities
+
         pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
         pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
-        self.assertNotIn("foo_exchange", amqp._exchanges_declared)
+        self.assertNotIn(pub._get_exchange("foo_exchange"), entities)
 
 
         dispatcher = Dispatcher()
         dispatcher = Dispatcher()
         self.assertTrue(pub.delay_task("footask", (), {},
         self.assertTrue(pub.delay_task("footask", (), {},
                                        exchange="moo_exchange",
                                        exchange="moo_exchange",
                                        routing_key="moo_exchange",
                                        routing_key="moo_exchange",
                                        event_dispatcher=dispatcher))
                                        event_dispatcher=dispatcher))
-        self.assertIn("moo_exchange", amqp._exchanges_declared)
+        self.assertIn(pub._get_exchange("moo_exchange"), entities)
         self.assertTrue(dispatcher.sent)
         self.assertTrue(dispatcher.sent)
         self.assertEqual(dispatcher.sent[0][0], "task-sent")
         self.assertEqual(dispatcher.sent[0][0], "task-sent")
         self.assertTrue(pub.delay_task("footask", (), {},
         self.assertTrue(pub.delay_task("footask", (), {},
                                        event_dispatcher=dispatcher,
                                        event_dispatcher=dispatcher,
                                        exchange="bar_exchange",
                                        exchange="bar_exchange",
                                        routing_key="bar_exchange"))
                                        routing_key="bar_exchange"))
-        self.assertIn("bar_exchange", amqp._exchanges_declared)
+        self.assertIn(pub._get_exchange("bar_exchange"), entities)
 
 
     def test_error_mail_sender(self):
     def test_error_mail_sender(self):
         x = ErrorMail.subject % {"name": "task_name",
         x = ErrorMail.subject % {"name": "task_name",

+ 2 - 0
celery/tests/test_app/test_app_amqp.py

@@ -42,11 +42,13 @@ class test_TaskPublisher(AppCase):
 
 
     def test_retry_policy(self):
     def test_retry_policy(self):
         pub = self.app.amqp.TaskPublisher(Mock())
         pub = self.app.amqp.TaskPublisher(Mock())
+        pub.channel.connection.client.declared_entities = set()
         pub.delay_task("tasks.add", (2, 2), {},
         pub.delay_task("tasks.add", (2, 2), {},
                        retry_policy={"frobulate": 32.4})
                        retry_policy={"frobulate": 32.4})
 
 
     def test_publish_no_retry(self):
     def test_publish_no_retry(self):
         pub = self.app.amqp.TaskPublisher(Mock())
         pub = self.app.amqp.TaskPublisher(Mock())
+        pub.channel.connection.client.declared_entities = set()
         pub.delay_task("tasks.add", (2, 2), {}, retry=False, chord=123)
         pub.delay_task("tasks.add", (2, 2), {}, retry=False, chord=123)
         self.assertFalse(pub.connection.ensure.call_count)
         self.assertFalse(pub.connection.ensure.call_count)