Browse Source

Fixes issues related to producer queue declarations.

This patch changes producer queue declaration behavior to the following:

    * Queues are only declared when needed.

        Previously, configured queues were only declared once: when
        the first message was sent.

    * Automatically configured queues are now declared in the same manner.

        These are queues created because of CREATE_MISSING_QUEUES.
        Kombu virtual transports require the producer to declare all
        queues, as the routing table is kept in producer memory.
        For AMQP this means we won't lose messages if there is no one
        consuming from this queue.
Ask Solem 14 years ago
parent
commit
3cbd7cedc8
5 changed files with 33 additions and 45 deletions
  1. 19 17
      celery/app/amqp.py
  2. 3 1
      celery/routes.py
  3. 3 2
      celery/task/base.py
  4. 1 1
      celery/tests/test_routes.py
  5. 7 24
      celery/tests/test_task.py

+ 19 - 17
celery/app/amqp.py

@@ -35,6 +35,9 @@ binding:%(binding_key)s
 #: Set of exchange names that have already been declared.
 _exchanges_declared = set()
 
+#: Set of queue names that have already been declared.
+_queues_declared = set()
+
 
 def extract_msg_options(options, keep=MSG_OPTIONS):
     """Extracts known options to `basic_publish` from a dict,
@@ -125,27 +128,37 @@ class Queues(UserDict):
 
 
 class TaskPublisher(messaging.Publisher):
-    auto_declare = False
+    auto_declare = True
     retry = False
     retry_policy = None
 
     def __init__(self, *args, **kwargs):
+        self.app = kwargs.pop("app")
         self.retry = kwargs.pop("retry", self.retry)
         self.retry_policy = kwargs.pop("retry_policy",
                                         self.retry_policy or {})
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
     def declare(self):
-        if self.exchange.name not in _exchanges_declared:
+        if self.exchange.name and \
+                self.exchange.name not in _exchanges_declared:
             super(TaskPublisher, self).declare()
             _exchanges_declared.add(self.exchange.name)
 
     def delay_task(self, task_name, task_args=None, task_kwargs=None,
             countdown=None, eta=None, task_id=None, taskset_id=None,
             expires=None, exchange=None, exchange_type=None,
-            event_dispatcher=None, retry=None, retry_policy=None, **kwargs):
+            event_dispatcher=None, retry=None, retry_policy=None,
+            queue=None, **kwargs):
         """Send task message."""
 
+        if queue and queue not in _queues_declared:
+            options = self.app.queues[queue]
+            entity = messaging.entry_to_queue(queue, **options)
+            entity(self.channel).declare()
+            _exchanges_declared.add(entity.exchange.name)
+            _queues_declared.add(entity.name)
+
         task_id = task_id or gen_unique_id()
         task_args = task_args or []
         task_kwargs = task_kwargs or {}
@@ -243,9 +256,6 @@ class AMQP(object):
     Consumer = messaging.Consumer
     ConsumerSet = messaging.ConsumerSet
 
-    #: Set to :const:`True` when the configured queues has been declared.
-    _queues_declared = False
-
     #: Cached and prepared routing table.
     _rtable = None
 
@@ -294,17 +304,9 @@ class AMQP(object):
                     "routing_key": conf.CELERY_DEFAULT_ROUTING_KEY,
                     "serializer": conf.CELERY_TASK_SERIALIZER,
                     "retry": conf.CELERY_TASK_PUBLISH_RETRY,
-                    "retry_policy": conf.CELERY_TASK_PUBLISH_RETRY_POLICY}
-        publisher = TaskPublisher(*args,
-                                  **self.app.merge(defaults, kwargs))
-
-        # Make sure all queues are declared.
-        if not self._queues_declared:
-            self.get_task_consumer(publisher.connection).close()
-            self._queues_declared = True
-        publisher.declare()
-
-        return publisher
+                    "retry_policy": conf.CELERY_TASK_PUBLISH_RETRY_POLICY,
+                    "app": self}
+        return TaskPublisher(*args, **self.app.merge(defaults, kwargs))
 
     def PublisherPool(self, limit=None):
         return PublisherPool(limit=limit, app=self.app)

+ 3 - 1
celery/routes.py

@@ -54,7 +54,9 @@ class Router(object):
                 if not self.create_missing:
                     raise QueueNotFound(
                         "Queue %r is not defined in CELERY_QUEUES" % queue)
-                dest = self.app.amqp.queues.add(queue, queue, queue)
+                dest = dict(self.app.amqp.queues.add(queue, queue, queue))
+            # needs to be declared by publisher
+            dest["queue"] = queue
             # routing_key and binding_key are synonyms.
             dest.setdefault("routing_key", dest.get("binding_key"))
             return lpmerge(dest, route)

+ 3 - 2
celery/task/base.py

@@ -259,7 +259,7 @@ class BaseTask(object):
 
     @classmethod
     def get_publisher(self, connection=None, exchange=None,
-            connect_timeout=None, exchange_type=None):
+            connect_timeout=None, exchange_type=None, **options):
         """Get a celery task message publisher.
 
         :rtype :class:`~celery.app.amqp.TaskPublisher`:
@@ -284,7 +284,8 @@ class BaseTask(object):
         return self.app.amqp.TaskPublisher(connection=connection,
                                            exchange=exchange,
                                            exchange_type=exchange_type,
-                                           routing_key=self.routing_key)
+                                           routing_key=self.routing_key,
+                                           **options)
 
     @classmethod
     def get_consumer(self, connection=None, connect_timeout=None):

+ 1 - 1
celery/tests/test_routes.py

@@ -93,7 +93,7 @@ class test_lookup_route(unittest.TestCase):
                                        "routing_key": "testq",
                                        "immediate": False},
                                        route)
-        self.assertNotIn("queue", route)
+        self.assertIn("queue", route)
 
     @with_queues(foo=a_queue, bar=b_queue)
     def test_expand_destaintion_string(self):

+ 7 - 24
celery/tests/test_task.py

@@ -198,17 +198,6 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEqual(RetryTask.iterations, 2)
 
 
-class MockPublisher(object):
-    _declared = False
-
-    def __init__(self, *args, **kwargs):
-        self.kwargs = kwargs
-        self.connection = app_or_default().broker_connection()
-
-    def declare(self):
-        self._declared = True
-
-
 class TestCeleryTasks(unittest.TestCase):
 
     def test_unpickle_task(self):
@@ -354,19 +343,13 @@ class TestCeleryTasks(unittest.TestCase):
         self.assertTrue(dispatcher[0])
 
     def test_get_publisher(self):
-        from celery.app import amqp
-        old_pub = amqp.TaskPublisher
-        amqp.TaskPublisher = MockPublisher
-        try:
-            p = IncrementCounterTask.get_publisher(exchange="foo",
-                                                   connection="bar")
-            self.assertEqual(p.kwargs["exchange"], "foo")
-            self.assertTrue(p._declared)
-            p = IncrementCounterTask.get_publisher(exchange_type="fanout",
-                                                   connection="bar")
-            self.assertEqual(p.kwargs["exchange_type"], "fanout")
-        finally:
-            amqp.TaskPublisher = old_pub
+        connection = app_or_default().broker_connection()
+        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
+                                               exchange="foo")
+        self.assertEqual(p.exchange.name, "foo")
+        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
+                                               exchange_type="fanout")
+        self.assertEqual(p.exchange.type, "fanout")
 
     def test_update_state(self):