Forráskód Böngészése

Introducing Celery.amqp.queues: .add, .select_subset, .format

* celery.conf.get_queues is now celery.amqp.queues.
* celery.utils.format_queues is now celery.amqp.queues.format

setting queues::

    >>> celery.amqp.queues = {"celery": {"exchange": "foo"}}

selecting a subset of the queues (same as the -Q option to celeryd)::

    >>> celery.amqp.queues.select_subset(["a", "b", "c"], create_missing=True)

adding a new queue:

    >>> celery.amqp.queues.add("foo", exchange="foo", routing_key="foo")
Ask Solem 14 éve
szülő
commit
d43d9646ba

+ 5 - 0
celery/__init__.py

@@ -1,4 +1,5 @@
 """Distributed Task Queue"""
+import os
 
 VERSION = (2, 1, 0, "a5")
 
@@ -12,3 +13,7 @@ __docformat__ = "restructuredtext"
 def Celery(*args, **kwargs):
     from celery import app
     return app.App(*args, **kwargs)
+
+
+def CompatCelery(*args, **kwargs):
+    return Celery(loader=os.environ.get("CELERY_LOADER", "default"))

+ 77 - 25
celery/app/amqp.py

@@ -1,5 +1,6 @@
 
 from datetime import datetime, timedelta
+from UserDict import UserDict
 
 from carrot.connection import BrokerConnection
 from carrot import messaging
@@ -9,6 +10,7 @@ from celery import signals
 from celery.utils import gen_unique_id, mitemgetter, textindent
 
 
+
 MSG_OPTIONS = ("mandatory", "priority", "immediate",
                "routing_key", "serializer", "delivery_mode")
 QUEUE_FORMAT = """
@@ -25,6 +27,59 @@ _queues_declared = False
 _exchanges_declared = set()
 
 
+class Queues(UserDict):
+
+    def __init__(self, queues):
+        self.data = {}
+        for queue_name, options in (queues or {}).items():
+            self.add(queue_name, **options)
+
+    def add(self, queue, exchange=None, routing_key=None,
+            exchange_type="direct", **options):
+        q = self[queue] = self.options(exchange, routing_key,
+                                       exchange_type, **options)
+        return q
+
+    def options(self, exchange, routing_key,
+            exchange_type="direct", **options):
+        return dict(options, routing_key=routing_key,
+                             binding_key=routing_key,
+                             exchange=exchange,
+                             exchange_type=exchange_type)
+
+    def format(self, indent=0):
+        """Format routing table into string for log dumps."""
+        format = lambda **queue: QUEUE_FORMAT.strip() % queue
+        info = "\n".join(format(name=name, **config)
+                                for name, config in self.items())
+        return textindent(info, indent=indent)
+
+    def select_subset(self, wanted, create_missing=True):
+        acc = {}
+        for queue in wanted:
+            try:
+                options = self[queue]
+            except KeyError:
+                if not create_missing:
+                    raise
+                options = self.options(queue, queue)
+            acc[queue] = options
+        self.data.clear()
+        self.data.update(acc)
+
+    @classmethod
+    def with_defaults(cls, queues, default_exchange, default_exchange_type):
+
+        def _defaults(opts):
+            opts.setdefault("exchange", default_exchange),
+            opts.setdefault("exchange_type", default_exchange_type)
+            opts.setdefault("binding_key", default_exchange)
+            opts.setdefault("routing_key", opts.get("binding_key"))
+            return opts
+
+        map(_defaults, queues.values())
+        return cls(queues)
+
 
 
 class TaskPublisher(messaging.Publisher):
@@ -118,33 +173,22 @@ class AMQP(object):
     Publisher = messaging.Publisher
     Consumer = messaging.Consumer
     ConsumerSet = ConsumerSet
+    _queues = None
 
     def __init__(self, app):
         self.app = app
 
-    def get_queues(self):
-        c = self.app.conf
-        queues = c.CELERY_QUEUES
-
-        def _defaults(opts):
-            opts.setdefault("exchange", c.CELERY_DEFAULT_EXCHANGE),
-            opts.setdefault("exchange_type", c.CELERY_DEFAULT_EXCHANGE_TYPE)
-            opts.setdefault("binding_key", c.CELERY_DEFAULT_EXCHANGE)
-            opts.setdefault("routing_key", opts.get("binding_key"))
-            return opts
-
-        return dict((queue, _defaults(opts))
-                    for queue, opts in queues.items())
-
-    def get_default_queue(self):
-        q = self.app.conf.CELERY_DEFAULT_QUEUE
-        return q, self.get_queues()[q]
+    def Queues(self, queues):
+        return Queues.with_defaults(queues,
+                                    self.app.conf.CELERY_DEFAULT_EXCHANGE,
+                                    self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE)
 
     def Router(self, queues=None, create_missing=None):
         return routes.Router(self.app.conf.CELERY_ROUTES,
                              queues or self.app.conf.CELERY_QUEUES,
                              self.app.either("CELERY_CREATE_MISSING_QUEUES",
-                                             create_missing))
+                                             create_missing),
+                             app=self.app)
 
     def TaskConsumer(self, *args, **kwargs):
         default_queue_name, default_queue = self.get_default_queue()
@@ -173,7 +217,7 @@ class AMQP(object):
         return publisher
 
     def get_consumer_set(self, connection, queues=None, **options):
-        queues = queues or self.get_queues()
+        queues = queues or self.queues
 
         cset = self.ConsumerSet(connection)
         for queue_name, queue_options in queues.items():
@@ -185,12 +229,9 @@ class AMQP(object):
             cset.consumers.append(consumer)
         return cset
 
-    def format_queues(self, queues, indent=0):
-        """Format routing table into string for log dumps."""
-        format = lambda **queue: QUEUE_FORMAT.strip() % queue
-        info = "\n".join(format(name=name, **config)
-                                for name, config in queues.items())
-        return textindent(info, indent=indent)
+    def get_default_queue(self):
+        q = self.app.conf.CELERY_DEFAULT_QUEUE
+        return q, self.queues[q]
 
     def get_broker_info(self):
         broker_connection = self.app.broker_connection()
@@ -216,3 +257,14 @@ class AMQP(object):
     def format_broker_info(self, info=None):
         """Get message broker connection info string for log dumps."""
         return BROKER_FORMAT % self.get_broker_info()
+
+    def _get_queues(self):
+        if self._queues is None:
+            c = self.app.conf
+            self._queues = self.Queues(c.CELERY_QUEUES)
+        return self._queues
+
+    def _set_queues(self, queues):
+        self._queues = self.Queues(queues)
+
+    queues = property(_get_queues, _set_queues)

+ 1 - 1
celery/apps/beat.py

@@ -7,7 +7,7 @@ from celery import beat
 from celery import platform
 from celery.app import app_or_default
 from celery.log import emergency_error
-from celery.utils import info, LOG_LEVELS
+from celery.utils import LOG_LEVELS
 
 STARTUP_INFO_FMT = """
 Configuration ->

+ 14 - 15
celery/apps/worker.py

@@ -11,7 +11,7 @@ from celery import platform
 from celery import signals
 from celery.app import app_or_default
 from celery.exceptions import ImproperlyConfigured
-from celery.utils import info, get_full_cls_name, LOG_LEVELS
+from celery.utils import get_full_cls_name, LOG_LEVELS
 from celery.worker import WorkController
 
 
@@ -99,20 +99,19 @@ class Worker(object):
         print("celery@%s has started." % self.hostname)
 
     def init_queues(self):
-        amqp = self.app.amqp
-        queues = amqp.get_queues()
         if self.use_queues:
-            queues = dict((queue, options)
-                                for queue, options in queues.items()
-                                    if queue in self.use_queues)
-            for queue in self.use_queues:
-                if queue not in queues:
-                    if self.app.conf.CELERY_CREATE_MISSING_QUEUES:
-                        amqp.Router(queues=queues).add_queue(queue)
-                    else:
-                        raise ImproperlyConfigured(
-                            "Queue '%s' not defined in CELERY_QUEUES" % queue)
-        self.queues = queues
+            create_missing = self.app.conf.CELERY_CREATE_MISSING_QUEUES
+            try:
+                self.app.amqp.queues.select_subset(self.use_queues,
+                                                   create_missing)
+            except KeyError, exc:
+                raise ImproperlyConfigured(
+                    "Trying to select queue subset of %r, but queue %s"
+                    "is not defined in CELERY_QUEUES. If you want to "
+                    "automatically declare unknown queues you have to "
+                    "enable CELERY_CREATE_MISSING_QUEUES" % (
+                        self.use_queues, exc))
+        self.queues = self.app.amqp.queues
 
     def init_loader(self):
         self.loader = self.app.loader
@@ -154,7 +153,7 @@ class Worker(object):
 
         return STARTUP_INFO_FMT % {
             "conninfo": self.app.amqp.format_broker_info(),
-            "queues": self.app.amqp.format_queues(self.queues, indent=8),
+            "queues": self.queues.format(indent=8),
             "concurrency": self.concurrency,
             "loglevel": LOG_LEVELS[self.loglevel],
             "logfile": self.logfile or "[stderr]",

+ 2 - 2
celery/bin/camqadm.py

@@ -14,7 +14,7 @@ from itertools import count
 from amqplib import client_0_8 as amqp
 from carrot.utils import partition
 
-from celery import Celery
+from celery import CompatCelery
 from celery.app import app_or_default
 from celery.utils import info
 from celery.utils import padlist
@@ -367,7 +367,7 @@ def parse_options(arguments):
 
 
 def camqadm(*args, **options):
-    options["app"] = Celery()
+    options["app"] = CompatCelery()
     return AMQPAdmin(*args, **options).run()
 
 

+ 2 - 2
celery/bin/celerybeat.py

@@ -22,7 +22,7 @@
     ``ERROR``, ``CRITICAL``, or ``FATAL``.
 
 """
-from celery import Celery
+from celery import CompatCelery
 from celery.bin.base import Command, Option
 
 
@@ -62,7 +62,7 @@ class BeatCommand(Command):
 
 
 def main():
-    app = Celery()
+    app = CompatCelery()
     beat = BeatCommand(app=app)
     beat.execute_from_commandline()
 

+ 2 - 2
celery/bin/celeryctl.py

@@ -8,7 +8,7 @@ from textwrap import wrap
 from anyjson import deserialize
 
 from celery import __version__
-from celery import Celery
+from celery import CompatCelery
 from celery.app import app_or_default
 from celery.utils import term
 
@@ -311,7 +311,7 @@ class celeryctl(object):
 
 def main():
     try:
-        app = Celery()
+        app = CompatCelery()
         celeryctl(app).execute_from_commandline()
     except KeyboardInterrupt:
         pass

+ 2 - 2
celery/bin/celeryd.py

@@ -70,7 +70,7 @@
 import multiprocessing
 
 from celery import __version__
-from celery import Celery
+from celery import CompatCelery
 from celery.bin.base import Command, Option
 
 
@@ -151,7 +151,7 @@ class WorkerCommand(Command):
 
 def main():
     multiprocessing.freeze_support()
-    app = Celery()
+    app = CompatCelery()
     worker = WorkerCommand(app=app)
     worker.execute_from_commandline()
 

+ 2 - 2
celery/bin/celeryev.py

@@ -3,7 +3,7 @@ import sys
 
 from optparse import OptionParser, make_option as Option
 
-from celery import Celery
+from celery import CompatCelery
 from celery.app import app_or_default
 from celery.events.cursesmon import evtop
 from celery.events.dumper import evdump
@@ -52,7 +52,7 @@ def parse_options(arguments):
 
 def main():
     options = parse_options(sys.argv[1:])
-    app = Celery()
+    app = CompatCelery()
     return run_celeryev(app=app, **vars(options))
 
 if __name__ == "__main__":

+ 0 - 19
celery/loaders/app.py

@@ -45,25 +45,6 @@ class AppLoader(BaseLoader):
     def on_worker_init(self):
         self.import_default_modules()
 
-    def import_from_cwd(self, module, imp=import_module):
-        """Import module, but make sure it finds modules
-        located in the current directory.
-
-        Modules located in the current directory has
-        precedence over modules located in ``sys.path``.
-        """
-        cwd = os.getcwd()
-        if cwd in sys.path:
-            return imp(module)
-        sys.path.insert(0, cwd)
-        try:
-            return imp(module)
-        finally:
-            try:
-                sys.path.remove(cwd)
-            except ValueError:
-                pass
-
     @property
     def conf(self):
         return self._conf

+ 27 - 3
celery/loaders/base.py

@@ -1,4 +1,7 @@
-from importlib import import_module
+import os
+import sys
+
+from importlib import import_module as _import_module
 
 BUILTIN_MODULES = ["celery.task"]
 
@@ -41,10 +44,10 @@ class BaseLoader(object):
         pass
 
     def import_task_module(self, module):
-        return self.import_module(module)
+        return self.import_from_cwd(module)
 
     def import_module(self, module):
-        return import_module(module)
+        return _import_module(module)
 
     def import_default_modules(self):
         imports = self.conf.get("CELERY_IMPORTS") or []
@@ -62,3 +65,24 @@ class BaseLoader(object):
         if not self._conf_cache:
             self._conf_cache = self.read_configuration()
         return self._conf_cache
+
+    def import_from_cwd(self, module, imp=None):
+        """Import module, but make sure it finds modules
+        located in the current directory.
+
+        Modules located in the current directory has
+        precedence over modules located in ``sys.path``.
+        """
+        if imp is None:
+            imp = self.import_module
+        cwd = os.getcwd()
+        if cwd in sys.path:
+            return imp(module)
+        sys.path.insert(0, cwd)
+        try:
+            return imp(module)
+        finally:
+            try:
+                sys.path.remove(cwd)
+            except ValueError:
+                pass

+ 5 - 9
celery/routes.py

@@ -18,22 +18,18 @@ class MapRoute(object):
 
 class Router(object):
 
-    def __init__(self, routes=None, queues=None, create_missing=False):
+    def __init__(self, routes=None, queues=None, create_missing=False,
+            app=None):
+        from celery.app import app_or_default
         if queues is None:
             queues = {}
         if routes is None:
             routes = []
+        self.app = app_or_default(app)
         self.queues = queues
         self.routes = routes
         self.create_missing = create_missing
 
-    def add_queue(self, queue):
-        q = self.queues[queue] = {"binding_key": queue,
-                                  "routing_key": queue,
-                                  "exchange": queue,
-                                  "exchange_type": "direct"}
-        return q
-
     def route(self, options, task, args=(), kwargs={}):
         # Expand "queue" keys in options.
         options = self.expand_destination(options)
@@ -59,7 +55,7 @@ class Router(object):
                 dest = dict(self.queues[queue])
             except KeyError:
                 if self.create_missing:
-                    dest = self.add_queue(queue)
+                    dest = self.app.amqp.queues.add(queue, queue, queue)
                 else:
                     raise QueueNotFound(
                         "Queue '%s' is not defined in CELERY_QUEUES" % queue)

+ 12 - 4
celery/tests/test_bin/test_celeryd.py

@@ -6,6 +6,7 @@ import unittest2 as unittest
 from multiprocessing import get_logger, current_process
 from StringIO import StringIO
 
+from celery import Celery
 from celery import platform
 from celery import signals
 from celery.app import default_app
@@ -25,7 +26,7 @@ def disable_stdouts(fun):
 
     @wraps(fun)
     def disable(*args, **kwargs):
-        sys.stdout, sys.stderr = StringIO(), StringIO()
+        #sys.stdout, sys.stderr = StringIO(), StringIO()
         try:
             return fun(*args, **kwargs)
         finally:
@@ -53,9 +54,11 @@ class test_Worker(unittest.TestCase):
 
     @disable_stdouts
     def test_queues_string(self):
-        worker = self.Worker(queues="foo,bar,baz")
+        celery = Celery()
+        worker = celery.Worker(queues="foo,bar,baz")
         worker.init_queues()
         self.assertEqual(worker.use_queues, ["foo", "bar", "baz"])
+        self.assertTrue("foo" in celery.amqp.queues)
 
     @disable_stdouts
     def test_loglevel_string(self):
@@ -121,7 +124,7 @@ class test_Worker(unittest.TestCase):
     @disable_stdouts
     def test_init_queues(self):
         c = default_app.conf
-        p, c.CELERY_QUEUES = c.CELERY_QUEUES, {
+        p, default_app.amqp.queues = default_app.amqp.queues, {
                 "celery": {"exchange": "celery",
                            "binding_key": "celery"},
                 "video": {"exchange": "video",
@@ -139,8 +142,13 @@ class test_Worker(unittest.TestCase):
             worker = self.Worker(queues=["image"])
             worker.init_queues()
             self.assertIn("image", worker.queues)
+            self.assertDictContainsSubset({"exchange": "image",
+                                           "routing_key": "image",
+                                           "binding_key": "image",
+                                           "exchange_type": "direct"},
+                                            worker.queues["image"])
         finally:
-            c.CELERY_QUEUES = p
+            default_app.amqp.queues = p
 
     @disable_stdouts
     def test_on_listener_ready(self):

+ 4 - 10
celery/tests/test_loaders.py

@@ -24,14 +24,8 @@ class TestLoaders(unittest.TestCase):
 
 class DummyLoader(base.BaseLoader):
 
-    class Config(object):
-
-        def __init__(self, **kwargs):
-            for attr, val in kwargs.items():
-                setattr(self, attr, val)
-
     def read_configuration(self):
-        return self.Config(foo="bar", CELERY_IMPORTS=("os", "sys"))
+        return {"foo": "bar", "CELERY_IMPORTS": ("os", "sys")}
 
 
 class TestLoaderBase(unittest.TestCase):
@@ -47,9 +41,9 @@ class TestLoaderBase(unittest.TestCase):
         self.assertEqual(sys, self.loader.import_task_module("sys"))
 
     def test_conf_property(self):
-        self.assertEqual(self.loader.conf.foo, "bar")
-        self.assertEqual(self.loader._conf_cache.foo, "bar")
-        self.assertEqual(self.loader.conf.foo, "bar")
+        self.assertEqual(self.loader.conf["foo"], "bar")
+        self.assertEqual(self.loader._conf_cache["foo"], "bar")
+        self.assertEqual(self.loader.conf["foo"], "bar")
 
     def test_import_default_modules(self):
         self.assertItemsEqual(self.loader.import_default_modules(),

+ 2 - 1
celery/tests/test_task.py

@@ -9,8 +9,9 @@ from celery import task
 from celery.app import default_app
 from celery.task.schedules import crontab, crontab_parser
 from celery.utils import timeutils
-from celery.utils import gen_unique_id, parse_iso8601
+from celery.utils import gen_unique_id
 from celery.utils.functional import wraps
+from celery.utils.timeutils import parse_iso8601
 from celery.result import EagerResult
 from celery.execute import send_task
 from celery.decorators import task as task_dec

+ 4 - 1
celery/tests/test_utils_info.py

@@ -1,5 +1,6 @@
 import unittest2 as unittest
 
+from celery import Celery
 from celery.app import default_app
 from celery.utils import textindent
 from celery.utils.timeutils import humanize_seconds
@@ -58,7 +59,9 @@ class TestInfo(unittest.TestCase):
         self.assertEqual(textindent(RANDTEXT, 4), RANDTEXT_RES)
 
     def test_format_queues(self):
-        self.assertEqual(default_app.amqp.format_queues(QUEUES), QUEUE_FORMAT)
+        celery = Celery()
+        celery.amqp.queues = QUEUES
+        self.assertEqual(celery.amqp.queues.format(), QUEUE_FORMAT)
 
     def test_broker_info(self):
         default_app.amqp.format_broker_info()