Przeglądaj źródła

Fixes issues related to producer queue declarations.

This patch changes producer queue declaration behavior to the following:

    * Queues are only declared when needed.

        Previously configured queues were only declared once: when
        the first message was sent.

    * Automatically configured queues are now declared in the same manner.

        These are queues created because of CREATE_MISSING_QUEUES.
        Kombu virtual transports requires the producer to declare all
        queues, as the routing table is kept in producer memory.
        For AMQP this means we won't lose messages if there are no one
        consuming from this queue.
Ask Solem 14 lat temu
rodzic
commit
3cbd7cedc8

+ 19 - 17
celery/app/amqp.py

@@ -35,6 +35,9 @@ binding:%(binding_key)s
 #: Set of exchange names that have already been declared.
 #: Set of exchange names that have already been declared.
 _exchanges_declared = set()
 _exchanges_declared = set()
 
 
+#: Set of queue names that have already been declared.
+_queues_declared = set()
+
 
 
 def extract_msg_options(options, keep=MSG_OPTIONS):
 def extract_msg_options(options, keep=MSG_OPTIONS):
     """Extracts known options to `basic_publish` from a dict,
     """Extracts known options to `basic_publish` from a dict,
@@ -125,27 +128,37 @@ class Queues(UserDict):
 
 
 
 
 class TaskPublisher(messaging.Publisher):
 class TaskPublisher(messaging.Publisher):
-    auto_declare = False
+    auto_declare = True
     retry = False
     retry = False
     retry_policy = None
     retry_policy = None
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
+        self.app = kwargs.pop("app")
         self.retry = kwargs.pop("retry", self.retry)
         self.retry = kwargs.pop("retry", self.retry)
         self.retry_policy = kwargs.pop("retry_policy",
         self.retry_policy = kwargs.pop("retry_policy",
                                         self.retry_policy or {})
                                         self.retry_policy or {})
         super(TaskPublisher, self).__init__(*args, **kwargs)
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
 
     def declare(self):
     def declare(self):
-        if self.exchange.name not in _exchanges_declared:
+        if self.exchange.name and \
+                self.exchange.name not in _exchanges_declared:
             super(TaskPublisher, self).declare()
             super(TaskPublisher, self).declare()
             _exchanges_declared.add(self.exchange.name)
             _exchanges_declared.add(self.exchange.name)
 
 
     def delay_task(self, task_name, task_args=None, task_kwargs=None,
     def delay_task(self, task_name, task_args=None, task_kwargs=None,
             countdown=None, eta=None, task_id=None, taskset_id=None,
             countdown=None, eta=None, task_id=None, taskset_id=None,
             expires=None, exchange=None, exchange_type=None,
             expires=None, exchange=None, exchange_type=None,
-            event_dispatcher=None, retry=None, retry_policy=None, **kwargs):
+            event_dispatcher=None, retry=None, retry_policy=None,
+            queue=None, **kwargs):
         """Send task message."""
         """Send task message."""
 
 
+        if queue and queue not in _queues_declared:
+            options = self.app.queues[queue]
+            entity = messaging.entry_to_queue(queue, **options)
+            entity(self.channel).declare()
+            _exchanges_declared.add(entity.exchange.name)
+            _queues_declared.add(entity.name)
+
         task_id = task_id or gen_unique_id()
         task_id = task_id or gen_unique_id()
         task_args = task_args or []
         task_args = task_args or []
         task_kwargs = task_kwargs or {}
         task_kwargs = task_kwargs or {}
@@ -243,9 +256,6 @@ class AMQP(object):
     Consumer = messaging.Consumer
     Consumer = messaging.Consumer
     ConsumerSet = messaging.ConsumerSet
     ConsumerSet = messaging.ConsumerSet
 
 
-    #: Set to :const:`True` when the configured queues has been declared.
-    _queues_declared = False
-
     #: Cached and prepared routing table.
     #: Cached and prepared routing table.
     _rtable = None
     _rtable = None
 
 
@@ -294,17 +304,9 @@ class AMQP(object):
                     "routing_key": conf.CELERY_DEFAULT_ROUTING_KEY,
                     "routing_key": conf.CELERY_DEFAULT_ROUTING_KEY,
                     "serializer": conf.CELERY_TASK_SERIALIZER,
                     "serializer": conf.CELERY_TASK_SERIALIZER,
                     "retry": conf.CELERY_TASK_PUBLISH_RETRY,
                     "retry": conf.CELERY_TASK_PUBLISH_RETRY,
-                    "retry_policy": conf.CELERY_TASK_PUBLISH_RETRY_POLICY}
-        publisher = TaskPublisher(*args,
-                                  **self.app.merge(defaults, kwargs))
-
-        # Make sure all queues are declared.
-        if not self._queues_declared:
-            self.get_task_consumer(publisher.connection).close()
-            self._queues_declared = True
-        publisher.declare()
-
-        return publisher
+                    "retry_policy": conf.CELERY_TASK_PUBLISH_RETRY_POLICY,
+                    "app": self}
+        return TaskPublisher(*args, **self.app.merge(defaults, kwargs))
 
 
     def PublisherPool(self, limit=None):
     def PublisherPool(self, limit=None):
         return PublisherPool(limit=limit, app=self.app)
         return PublisherPool(limit=limit, app=self.app)

+ 3 - 1
celery/routes.py

@@ -54,7 +54,9 @@ class Router(object):
                 if not self.create_missing:
                 if not self.create_missing:
                     raise QueueNotFound(
                     raise QueueNotFound(
                         "Queue %r is not defined in CELERY_QUEUES" % queue)
                         "Queue %r is not defined in CELERY_QUEUES" % queue)
-                dest = self.app.amqp.queues.add(queue, queue, queue)
+                dest = dict(self.app.amqp.queues.add(queue, queue, queue))
+            # needs to be declared by publisher
+            dest["queue"] = queue
             # routing_key and binding_key are synonyms.
             # routing_key and binding_key are synonyms.
             dest.setdefault("routing_key", dest.get("binding_key"))
             dest.setdefault("routing_key", dest.get("binding_key"))
             return lpmerge(dest, route)
             return lpmerge(dest, route)

+ 3 - 2
celery/task/base.py

@@ -259,7 +259,7 @@ class BaseTask(object):
 
 
     @classmethod
     @classmethod
     def get_publisher(self, connection=None, exchange=None,
     def get_publisher(self, connection=None, exchange=None,
-            connect_timeout=None, exchange_type=None):
+            connect_timeout=None, exchange_type=None, **options):
         """Get a celery task message publisher.
         """Get a celery task message publisher.
 
 
         :rtype :class:`~celery.app.amqp.TaskPublisher`:
         :rtype :class:`~celery.app.amqp.TaskPublisher`:
@@ -284,7 +284,8 @@ class BaseTask(object):
         return self.app.amqp.TaskPublisher(connection=connection,
         return self.app.amqp.TaskPublisher(connection=connection,
                                            exchange=exchange,
                                            exchange=exchange,
                                            exchange_type=exchange_type,
                                            exchange_type=exchange_type,
-                                           routing_key=self.routing_key)
+                                           routing_key=self.routing_key,
+                                           **options)
 
 
     @classmethod
     @classmethod
     def get_consumer(self, connection=None, connect_timeout=None):
     def get_consumer(self, connection=None, connect_timeout=None):

+ 1 - 1
celery/tests/test_routes.py

@@ -93,7 +93,7 @@ class test_lookup_route(unittest.TestCase):
                                        "routing_key": "testq",
                                        "routing_key": "testq",
                                        "immediate": False},
                                        "immediate": False},
                                        route)
                                        route)
-        self.assertNotIn("queue", route)
+        self.assertIn("queue", route)
 
 
     @with_queues(foo=a_queue, bar=b_queue)
     @with_queues(foo=a_queue, bar=b_queue)
     def test_expand_destaintion_string(self):
     def test_expand_destaintion_string(self):

+ 7 - 24
celery/tests/test_task.py

@@ -198,17 +198,6 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEqual(RetryTask.iterations, 2)
         self.assertEqual(RetryTask.iterations, 2)
 
 
 
 
-class MockPublisher(object):
-    _declared = False
-
-    def __init__(self, *args, **kwargs):
-        self.kwargs = kwargs
-        self.connection = app_or_default().broker_connection()
-
-    def declare(self):
-        self._declared = True
-
-
 class TestCeleryTasks(unittest.TestCase):
 class TestCeleryTasks(unittest.TestCase):
 
 
     def test_unpickle_task(self):
     def test_unpickle_task(self):
@@ -354,19 +343,13 @@ class TestCeleryTasks(unittest.TestCase):
         self.assertTrue(dispatcher[0])
         self.assertTrue(dispatcher[0])
 
 
     def test_get_publisher(self):
     def test_get_publisher(self):
-        from celery.app import amqp
-        old_pub = amqp.TaskPublisher
-        amqp.TaskPublisher = MockPublisher
-        try:
-            p = IncrementCounterTask.get_publisher(exchange="foo",
-                                                   connection="bar")
-            self.assertEqual(p.kwargs["exchange"], "foo")
-            self.assertTrue(p._declared)
-            p = IncrementCounterTask.get_publisher(exchange_type="fanout",
-                                                   connection="bar")
-            self.assertEqual(p.kwargs["exchange_type"], "fanout")
-        finally:
-            amqp.TaskPublisher = old_pub
+        connection = app_or_default().broker_connection()
+        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
+                                               exchange="foo")
+        self.assertEqual(p.exchange.name, "foo")
+        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
+                                               exchange_type="fanout")
+        self.assertEqual(p.exchange.type, "fanout")
 
 
     def test_update_state(self):
     def test_update_state(self):