Browse Source

Tests passing

Ask Solem 13 years ago
parent
commit
a1f95e51cd

+ 35 - 48
celery/app/amqp.py

@@ -34,25 +34,47 @@ class Queues(dict):
     """Queue name⇒ declaration mapping.
 
     :param queues: Initial list/tuple or dict of queues.
+    :keyword create_missing: By default any unknown queues will be
+                             added automatically, but if disabled
+                             the occurrence of unknown queues
+                             in `wanted` will raise :exc:`KeyError`.
+
 
     """
     #: If set, this is a subset of queues to consume from.
     #: The rest of the queues are then used for routing only.
     _consume_from = None
 
-    def __init__(self, queues):
+    def __init__(self, queues, default_exchange=None, create_missing=True):
+        dict.__init__(self)
         self.aliases = WeakValueDictionary()
+        self.default_exchange = default_exchange
+        self.create_missing = create_missing
         if isinstance(queues, (tuple, list)):
             queues = dict((q.name, q) for q in queues)
-        dict.__init__(self)
-        for name, q in (queues or {}).items():
+        for name, q in (queues or {}).iteritems():
             self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q)
 
     def __getitem__(self, name):
         try:
-            return dict.__getitem__(self, key)
+            return self.aliases[name]
         except KeyError:
-            return self.aliases[key]
+            return dict.__getitem__(self, name)
+
+    def __setitem__(self, name, queue):
+        if self.default_exchange:
+            if not queue.exchange or not queue.exchange.name:
+                queue.exchange = self.default_exchange
+            if queue.exchange.type == 'direct' and not queue.routing_key:
+                queue.routing_key = name
+        dict.__setitem__(self, name, queue)
+        if queue.alias:
+            self.aliases[queue.alias] = queue
+
+    def __missing__(self, name):
+        if self.create_missing:
+            return self.add(self.new_missing(name))
+        raise KeyError(name)
 
     def add(self, queue, **kwargs):
         """Add new queue.
@@ -67,23 +89,13 @@ class Queues(dict):
         if not isinstance(queue, Queue):
             return self.add_compat(queue, **kwargs)
         self[queue.name] = queue
-        if queue.alias:
-            self.aliases[queue.alias] = queue
         return queue
 
     def add_compat(self, name, **options):
         # docs used to use binding_key as routing key
         options.setdefault("routing_key", options.get("binding_key"))
-        self[name] = queue = entry_to_queue(name, **options)
-        return queue
-
-    def options(self, exchange, routing_key,
-            exchange_type="direct", **options):
-        """Creates new option mapping for queue, with required
-        keys present."""
-        return dict(options, routing_key=routing_key,
-                             exchange=exchange,
-                             exchange_type=exchange_type)
+        q = self[name] = entry_to_queue(name, **options)
+        return q
 
     def format(self, indent=0, indent_first=True):
         """Format routing table into string for log dumps."""
@@ -100,29 +112,14 @@ class Queues(dict):
             return text.indent("\n".join(info), indent)
         return info[0] + "\n" + text.indent("\n".join(info[1:]), indent)
 
-    def select_subset(self, wanted, create_missing=True):
+    def select_subset(self, wanted):
         """Sets :attr:`consume_from` by selecting a subset of the
         currently defined queues.
 
         :param wanted: List of wanted queue names.
-        :keyword create_missing: By default any unknown queues will be
-                                 added automatically, but if disabled
-                                 the occurrence of unknown queues
-                                 in `wanted` will raise :exc:`KeyError`.
-
         """
         if wanted:
-            acc = {}
-            for queue in wanted:
-                try:
-                    Q = self[queue]
-                except KeyError:
-                    if not create_missing:
-                        raise
-                    Q = self.new_missing(queue)
-                acc[queue] = Q
-            self._consume_from = acc
-            self.update(acc)
+            self._consume_from = dict((name, self[name]) for name in wanted)
 
     def new_missing(self, name):
         return Queue(name, Exchange(name), name)
@@ -133,18 +130,6 @@ class Queues(dict):
             return self._consume_from
         return self
 
-    @classmethod
-    def with_defaults(cls, queues, default_exchange):
-        """Alternate constructor that adds default exchange and
-        exchange type information to queues that does not have any."""
-        queues = cls(queues if queues is not None else {})
-        for q in queues.itervalues():
-            if not q.exchange or not q.exchange.name:
-                q.exchange = default_exchange
-            if not q.routing_key:
-                q.routing_key = default_exchange.name
-        return queues
-
 
 class TaskProducer(Producer):
     auto_declare = False
@@ -250,15 +235,17 @@ class AMQP(object):
     def flush_routes(self):
         self._rtable = _routes.prepare(self.app.conf.CELERY_ROUTES)
 
-    def Queues(self, queues):
+    def Queues(self, queues, create_missing=None):
         """Create new :class:`Queues` instance, using queue defaults
         from the current configuration."""
         conf = self.app.conf
+        if create_missing is None:
+            create_missing = conf.CELERY_CREATE_MISSING_QUEUES
         if not queues and conf.CELERY_DEFAULT_QUEUE:
             queues = (Queue(conf.CELERY_DEFAULT_QUEUE,
                             exchange=self.default_exchange,
                             routing_key=conf.CELERY_DEFAULT_ROUTING_KEY), )
-        return Queues.with_defaults(queues, self.default_exchange)
+        return Queues(queues, self.default_exchange, create_missing)
 
     def Router(self, queues=None, create_missing=None):
         """Returns the current task router."""

+ 1 - 2
celery/app/base.py

@@ -251,8 +251,7 @@ class Celery(object):
                                        use_tls=self.conf.EMAIL_USE_TLS)
 
     def select_queues(self, queues=None):
-        return self.amqp.queues.select_subset(queues,
-                                self.conf.CELERY_CREATE_MISSING_QUEUES)
+        return self.amqp.queues.select_subset(queues)
 
     def either(self, default_key, *values):
         """Fallback to the value of a configuration key if none of the

+ 2 - 3
celery/tests/app/test_amqp.py

@@ -3,6 +3,7 @@ from __future__ import with_statement
 
 from mock import Mock
 
+from celery.app.amqp import Queues
 from celery.tests.utils import AppCase
 
 
@@ -98,6 +99,4 @@ class test_Queues(AppCase):
             self.app.amqp.queues._consume_from = prev
 
     def test_with_defaults(self):
-        self.assertEqual(
-            self.app.amqp.queues.with_defaults(None,
-                self.app.amqp.default_exchange), {})
+        self.assertEqual(Queues(None), {})

+ 4 - 3
celery/tests/app/test_routes.py

@@ -82,7 +82,8 @@ class test_MapRoute(RouteCase):
         self.assertIsNone(route.route_for_task("celery.awesome"))
 
     def test_expand_route_not_found(self):
-        expand = E(current_app.amqp.queues)
+        expand = E(current_app.amqp.Queues(
+                    current_app.conf.CELERY_QUEUES, False))
         route = routes.MapRoute({"a": {"queue": "x"}})
         with self.assertRaises(QueueNotFound):
             expand(route.route_for_task("a"))
@@ -115,10 +116,10 @@ class test_lookup_route(RouteCase):
                               "immediate": False},
                              mytask.name,
                              args=[1, 2], kwargs={})
-        self.assertDictContainsSubset({"exchange": "testq",
-                                       "routing_key": "testq",
+        self.assertDictContainsSubset({"routing_key": "testq",
                                        "immediate": False},
                                        route)
+        self.assertEqual(route["exchange"].name, "testq")
         self.assertIn("queue", route)
 
     @with_queues(foo=a_queue, bar=b_queue)

+ 2 - 0
celery/tests/bin/test_celeryd.py

@@ -179,8 +179,10 @@ class test_Worker(AppCase):
             self.assertNotIn("celery", app.amqp.queues.consume_from)
 
             c.CELERY_CREATE_MISSING_QUEUES = False
+            del(app.amqp.queues)
             with self.assertRaises(ImproperlyConfigured):
                 self.Worker(queues=["image"]).init_queues()
+            del(app.amqp.queues)
             c.CELERY_CREATE_MISSING_QUEUES = True
             worker = self.Worker(queues=["image"])
             worker.init_queues()