Browse Source

100% Coverage for celery.app.*

Ask Solem 14 years ago
parent
commit
a6357948ab

+ 22 - 7
celery/app/__init__.py

@@ -130,15 +130,24 @@ class App(base.BaseApp):
         return inner_create_task_cls(**options)
 
     def __reduce__(self):
+        # Reduce only pickles the configuration changes,
+        # so the default configuration doesn't have to be passed
+        # between processes.
         return (_unpickle_app, (self.__class__,
                                 self.main,
                                 self.conf.changes,
                                 self.loader_cls,
-                                self.backend_cls))
+                                self.backend_cls,
+                                self.amqp_cls,
+                                self.events_cls,
+                                self.log_cls,
+                                self.control_cls))
 
 
-def _unpickle_app(cls, main, changes, loader, backend, set_as_current):
-    app = cls(main, loader=loader, backend=backend,
+def _unpickle_app(cls, main, changes, loader, backend, amqp,
+                  events, log, control):
+    app = cls(main, loader=loader, backend=backend, amqp=amqp,
+                    events=events, log=log, control=control,
                     set_as_current=False)
     app.conf.update(changes)
     return app
@@ -148,9 +157,15 @@ def _unpickle_app(cls, main, changes, loader, backend, set_as_current):
 default_loader = os.environ.get("CELERY_LOADER") or "default"
 
 #: Global fallback app instance.
-default_app = App(loader=default_loader, set_as_current=False)
+_default_app = None
 
-if os.environ.get("CELERY_TRACE_APP"):
+def _get_default_app():
+    global _default_app
+    if _default_app is None:
+        _default_app = App(loader=default_loader, set_as_current=False)
+    return _default_app
+
+if os.environ.get("CELERY_TRACE_APP"):  # pragma: no cover
 
     def app_or_default(app=None):
         from traceback import print_stack
@@ -164,7 +179,7 @@ if os.environ.get("CELERY_TRACE_APP"):
                 raise Exception("DEFAULT APP")
             print("-- RETURNING TO DEFAULT APP --")
             print_stack()
-            return default_app
+            return _get_default_app()
         return app
 else:
     def app_or_default(app=None):
@@ -176,5 +191,5 @@ else:
 
         """
         if app is None:
-            return getattr(_tls, "current_app", None) or default_app
+            return getattr(_tls, "current_app", None) or _get_default_app()
         return app

+ 1 - 0
celery/app/amqp.py

@@ -183,6 +183,7 @@ class TaskPublisher(messaging.Publisher):
                                           type=exchange_type,
                                           durable=self.durable,
                                           auto_delete=self.auto_delete)
+            _exchanges_declared.add(exchange)
         self.send(message_data, exchange=exchange,
                   **extract_msg_options(kwargs))
         signals.task_sent.send(sender=task_name, **message_data)

+ 45 - 34
celery/app/base.py

@@ -16,7 +16,7 @@ from datetime import timedelta
 from celery import routes
 from celery.app.defaults import DEFAULTS
 from celery.datastructures import ConfigurationView
-from celery.utils import noop, isatty, cached_property
+from celery.utils import instantiate, isatty, cached_property, maybe_promise
 from celery.utils.functional import wraps
 
 
@@ -26,11 +26,23 @@ class BaseApp(object):
     IS_OSX = SYSTEM == "Darwin"
     IS_WINDOWS = SYSTEM == "Windows"
 
+    amqp_cls = "celery.app.amqp.AMQP"
+    backend_cls = None
+    events_cls = "celery.events.Events"
+    loader_cls = "app"
+    log_cls = "celery.log.Logging"
+    control_cls = "celery.task.control.Control"
+
     def __init__(self, main=None, loader=None, backend=None,
+            amqp=None, events=None, log=None, control=None,
             set_as_current=True):
         self.main = main
-        self.loader_cls = loader or "app"
-        self.backend_cls = backend
+        self.amqp_cls = amqp or self.amqp_cls
+        self.backend_cls = backend or self.backend_cls
+        self.events_cls = events or self.events_cls
+        self.loader_cls = loader or self.loader_cls
+        self.log_cls = log or self.log_cls
+        self.control_cls = control or self.control_cls
         self.set_as_current = set_as_current
         self.on_init()
 
@@ -174,12 +186,13 @@ class BaseApp(object):
             timeout = kwargs.get("connect_timeout")
             kwargs["connection"] = conn = connection or \
                     self.broker_connection(connect_timeout=timeout)
-            close_connection = not connection and conn.close or noop
+            close_connection = not connection and conn.close or None
 
             try:
                 return fun(*args, **kwargs)
             finally:
-                close_connection()
+                if close_connection:
+                    close_connection()
         return _inner
 
     def pre_config_merge(self, c):
@@ -207,13 +220,14 @@ class BaseApp(object):
         if c.get("CELERYD_LOG_COLOR") is None:
             c["CELERYD_LOG_COLOR"] = not c.CELERYD_LOG_FILE and \
                                         isatty(sys.stderr)
-            if self.IS_WINDOWS:  # windows console doesn't support ANSI colors
-                c["CELERYD_LOG_COLOR"] = False
+        if self.IS_WINDOWS:  # windows console doesn't support ANSI colors
+            c["CELERYD_LOG_COLOR"] = False
         if isinstance(c.CELERY_TASK_RESULT_EXPIRES, int):
             c["CELERY_TASK_RESULT_EXPIRES"] = timedelta(
                     seconds=c.CELERY_TASK_RESULT_EXPIRES)
 
         # Install backend cleanup periodic task.
+        c.CELERYBEAT_SCHEDULE = maybe_promise(c.CELERYBEAT_SCHEDULE)
         if c.CELERY_TASK_RESULT_EXPIRES:
             from celery.schedules import crontab
             c.CELERYBEAT_SCHEDULE.setdefault("celery.backend_cleanup",
@@ -229,13 +243,14 @@ class BaseApp(object):
         if not self.conf.ADMINS:
             return
         to = [admin_email for _, admin_email in self.conf.ADMINS]
-        self.loader.mail_admins(subject, body, fail_silently,
-                                to=to, sender=self.conf.SERVER_EMAIL,
-                                host=self.conf.EMAIL_HOST,
-                                port=self.conf.EMAIL_PORT,
-                                user=self.conf.EMAIL_HOST_USER,
-                                password=self.conf.EMAIL_HOST_PASSWORD,
-                                timeout=self.conf.EMAIL_TIMEOUT)
+        return self.loader.mail_admins(subject, body, fail_silently,
+                                       to=to,
+                                       sender=self.conf.SERVER_EMAIL,
+                                       host=self.conf.EMAIL_HOST,
+                                       port=self.conf.EMAIL_PORT,
+                                       user=self.conf.EMAIL_HOST_USER,
+                                       password=self.conf.EMAIL_HOST_PASSWORD,
+                                       timeout=self.conf.EMAIL_TIMEOUT)
 
     def either(self, default_key, *values):
         """Fallback to the value of a configuration key if none of the
@@ -271,8 +286,7 @@ class BaseApp(object):
         See :class:`~celery.app.amqp.AMQP`.
 
         """
-        from celery.app.amqp import AMQP
-        return AMQP(self)
+        return instantiate(self.amqp_cls, app=self)
 
     @cached_property
     def backend(self):
@@ -283,12 +297,6 @@ class BaseApp(object):
         """
         return self._get_backend()
 
-    @cached_property
-    def loader(self):
-        """Current loader."""
-        from celery.loaders import get_loader_cls
-        return get_loader_cls(self.loader_cls)(app=self)
-
     @cached_property
     def conf(self):
         """Current configuration (dict and attribute access)."""
@@ -301,25 +309,28 @@ class BaseApp(object):
         See :class:`~celery.task.control.Control`.
 
         """
-        from celery.task.control import Control
-        return Control(app=self)
+        return instantiate(self.control_cls, app=self)
 
     @cached_property
-    def log(self):
-        """Logging utilities.
+    def events(self):
+        """Sending/receiving events.
 
-        See :class:`~celery.log.Logging`.
+        See :class:`~celery.events.Events`.
 
         """
-        from celery.log import Logging
-        return Logging(app=self)
+        return instantiate(self.events_cls, app=self)
 
     @cached_property
-    def events(self):
-        """Sending/receiving events.
+    def loader(self):
+        """Current loader."""
+        from celery.loaders import get_loader_cls
+        return get_loader_cls(self.loader_cls)(app=self)
 
-        See :class:`~celery.events.Events`.
+    @cached_property
+    def log(self):
+        """Logging utilities.
+
+        See :class:`~celery.log.Logging`.
 
         """
-        from celery.events import Events
-        return Events(app=self)
+        return instantiate(self.log_cls, app=self)

+ 3 - 1
celery/app/defaults.py

@@ -1,5 +1,7 @@
 from datetime import timedelta
 
+from celery.utils import promise
+
 DEFAULT_PROCESS_LOG_FMT = """
     [%(asctime)s: %(levelname)s/%(processName)s] %(message)s
 """.strip()
@@ -116,7 +118,7 @@ NAMESPACES = {
         "TASK_TIME_LIMIT": Option(type="int"),
     },
     "CELERYBEAT": {
-        "SCHEDULE": Option({}, type="dict"),
+        "SCHEDULE": Option(promise(lambda: {}), type="dict"),
         "SCHEDULER": Option("celery.beat.PersistentScheduler"),
         "SCHEDULE_FILENAME": Option("celerybeat-schedule"),
         "MAX_LOOP_INTERVAL": Option(5 * 60, type="int"),

+ 6 - 0
celery/datastructures.py

@@ -131,6 +131,12 @@ class ConfigurationView(AttributeDictMixin):
         return chain(*[d.iteritems() for d in (self.__dict__["changes"],
                                                self.__dict__["defaults"])])
 
+    def iteritems(self):
+        return iter(self)
+
+    def iter(self):
+        return tuple(iter(self))
+
 
 class PositionQueue(UserList):
     """A positional queue of a specific length, with slots that are either

+ 1 - 1
celery/events/__init__.py

@@ -212,7 +212,7 @@ class EventReceiver(object):
 
 class Events(object):
 
-    def __init__(self, app):
+    def __init__(self, app=None):
         self.app = app
 
     def Receiver(self, connection, handlers=None, routing_key="#"):

+ 0 - 28
celery/task/builtins.py

@@ -1,28 +0,0 @@
-from celery import conf
-from celery.schedules import crontab
-from celery.task.base import Task
-
-
-class backend_cleanup(Task):
-    name = "celery.backend_cleanup"
-
-    def run(self):
-        self.backend.cleanup()
-
-if conf.TASK_RESULT_EXPIRES and \
-        backend_cleanup.name not in conf.CELERYBEAT_SCHEDULE:
-    conf.CELERYBEAT_SCHEDULE[backend_cleanup.name] = dict(
-            task=backend_cleanup.name,
-            schedule=crontab(minute="00", hour="04", day_of_week="*"))
-
-
-DeleteExpiredTaskMetaTask = backend_cleanup         # FIXME remove in 3.0
-
-
-class PingTask(Task):
-    """The task used by :func:`ping`."""
-    name = "celery.ping"
-
-    def run(self, **kwargs):
-        """:returns: the string `"pong"`."""
-        return "pong"

+ 31 - 3
celery/tests/__init__.py

@@ -2,10 +2,12 @@ import logging
 import os
 import sys
 
-config = os.environ.setdefault("CELERY_TEST_CONFIG_MODULE",
-                               "celery.tests.config")
+from importlib import import_module
 
-os.environ["CELERY_CONFIG_MODULE"] = config
+config_module = os.environ.setdefault("CELERY_TEST_CONFIG_MODULE",
+                                      "celery.tests.config")
+
+os.environ["CELERY_CONFIG_MODULE"] = config_module
 os.environ["CELERY_LOADER"] = "default"
 os.environ["EVENTLET_NOPATCH"] = "yes"
 os.environ["GEVENT_NOPATCH"] = "yes"
@@ -25,3 +27,29 @@ def teardown():
         sys.stderr.write(
             "\n\n**WARNING**: Remaning threads at teardown: %r...\n" % (
                 remaining_threads))
+
+
+
+def find_distribution_modules(name=__name__, file=__file__):
+    current_dist_depth = len(name.split(".")) - 1
+    current_dist = os.path.join(os.path.dirname(file),
+                                *([os.pardir] * current_dist_depth))
+    abs = os.path.abspath(current_dist)
+    dist_name = os.path.basename(abs)
+
+    for dirpath, dirnames, filenames in os.walk(abs):
+        package = (dist_name + dirpath[len(abs):]).replace("/", ".")
+        if "__init__.py" in filenames:
+            yield package
+            for filename in filenames:
+                if filename.endswith(".py") and filename != "__init__.py":
+                    yield ".".join([package, filename])[:-3]
+
+
+def import_all_modules(name=__name__, file=__file__):
+    for module in find_distribution_modules(name, file):
+        import_module(module)
+
+
+if os.environ.get("COVER_ALL_MODULES") or "--with-coverage3" in sys.argv:
+    import_all_modules()

+ 214 - 0
celery/tests/test_app.py

@@ -0,0 +1,214 @@
+import os
+
+from datetime import timedelta
+
+from celery import Celery
+from celery.app import defaults
+from celery.app.base import BaseApp
+from celery.loaders.base import BaseLoader
+from celery.utils.serialization import pickle
+
+from celery.tests import config
+from celery.tests.utils import unittest
+
+THIS_IS_A_KEY = "this is a value"
+
+
+class Object(object):
+
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+
+def _get_test_config():
+    return dict((key, getattr(config, key))
+                    for key in dir(config)
+                        if key.isupper() and not key.startswith("_"))
+
+test_config = _get_test_config()
+
+
+class test_App(unittest.TestCase):
+
+    def setUp(self):
+        self.app = Celery()
+        self.app.conf.update(test_config)
+
+    def test_TaskSet(self):
+        ts = self.app.TaskSet()
+        self.assertListEqual(ts.tasks, [])
+        self.assertIs(ts.app, self.app)
+
+    def test_pickle_app(self):
+        changes = dict(THE_FOO_BAR="bars",
+                       THE_MII_MAR="jars")
+        self.app.conf.update(changes)
+        saved = pickle.dumps(self.app)
+        self.assertLess(len(saved), 2048)
+        restored = pickle.loads(saved)
+        self.assertDictContainsSubset(changes, restored.conf)
+
+    def test_worker_main(self):
+        from celery.bin import celeryd
+
+        class WorkerCommand(celeryd.WorkerCommand):
+
+            def execute_from_commandline(self, argv):
+                return argv
+
+        prev, celeryd.WorkerCommand = celeryd.WorkerCommand, WorkerCommand
+        try:
+            ret = self.app.worker_main(argv=["--version"])
+            self.assertListEqual(ret, ["--version"])
+        finally:
+            celeryd.WorkerCommand = prev
+
+    def test_config_from_envvar(self):
+        os.environ["CELERYTEST_CONFIG_OBJECT"] = "celery.tests.test_app"
+        self.app.config_from_envvar("CELERYTEST_CONFIG_OBJECT")
+        self.assertEqual(self.app.conf.THIS_IS_A_KEY, "this is a value")
+
+    def test_config_from_object(self):
+
+        class Object(object):
+            LEAVE_FOR_WORK = True
+            MOMENT_TO_STOP = True
+            CALL_ME_BACK = 123456789
+            WANT_ME_TO = False
+            UNDERSTAND_ME = True
+
+        self.app.config_from_object(Object())
+
+        self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
+        self.assertTrue(self.app.conf.MOMENT_TO_STOP)
+        self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
+        self.assertFalse(self.app.conf.WANT_ME_TO)
+        self.assertTrue(self.app.conf.UNDERSTAND_ME)
+
+    def test_config_from_cmdline(self):
+        cmdline = [".always_eager=no",
+                   ".result_backend=/dev/null",
+                   '.task_error_whitelist=(list)["a", "b", "c"]',
+                   "celeryd.prefetch_multiplier=368",
+                   ".foobarstring=(string)300",
+                   ".foobarint=(int)300",
+                   '.result_engine_options=(dict){"foo": "bar"}']
+        self.app.config_from_cmdline(cmdline, namespace="celery")
+        self.assertFalse(self.app.conf.CELERY_ALWAYS_EAGER)
+        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "/dev/null")
+        self.assertEqual(self.app.conf.CELERYD_PREFETCH_MULTIPLIER, 368)
+        self.assertListEqual(self.app.conf.CELERY_TASK_ERROR_WHITELIST,
+                             ["a", "b", "c"])
+        self.assertEqual(self.app.conf.CELERY_FOOBARSTRING, "300")
+        self.assertEqual(self.app.conf.CELERY_FOOBARINT, 300)
+        self.assertDictEqual(self.app.conf.CELERY_RESULT_ENGINE_OPTIONS,
+                             {"foo": "bar"})
+
+    def test_compat_setting_CELERY_BACKEND(self):
+
+        self.app.config_from_object(Object(CELERY_BACKEND="set_by_us"))
+        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "set_by_us")
+
+    def test_Windows_log_color_disabled(self):
+        self.app.IS_WINDOWS = True
+        self.app.config_from_object(Object(CELERYD_LOG_COLOR=True))
+        self.assertFalse(self.app.conf.CELERYD_LOG_COLOR)
+
+    def test_task_result_expires_converted_to_timedelta(self):
+        self.app.config_from_object(Object(CELERY_TASK_RESULT_EXPIRES=100))
+        self.assertEqual(self.app.conf.CELERY_TASK_RESULT_EXPIRES,
+                         timedelta(seconds=100))
+
+        self.assertIn("celery.backend_cleanup",
+                      self.app.conf.CELERYBEAT_SCHEDULE)
+
+    def test_backend_cleanup_not_installed(self):
+        self.app.config_from_object(Object(CELERY_TASK_RESULT_EXPIRES=None))
+        self.assertNotIn("celery.backend_cleanup",
+                         self.app.conf.CELERYBEAT_SCHEDULE)
+
+    def test_compat_setting_CARROT_BACKEND(self):
+        self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
+        self.assertEqual(self.app.conf.BROKER_BACKEND, "set_by_us")
+
+    def test_mail_admins(self):
+
+        class Loader(BaseLoader):
+
+            def mail_admins(*args, **kwargs):
+                return args, kwargs
+
+        self.app.loader = Loader()
+        self.app.conf.ADMINS = None
+        self.assertFalse(self.app.mail_admins("Subject", "Body"))
+        self.app.conf.ADMINS = [("George Costanza", "george@vandelay.com")]
+        self.assertTrue(self.app.mail_admins("Subject", "Body"))
+
+    def test_amqp_get_broker_info(self):
+        self.assertDictContainsSubset({"hostname": "localhost",
+                                       "userid": "guest",
+                                       "password": "guest",
+                                       "virtual_host": "/"},
+                                      self.app.amqp.get_broker_info())
+        self.app.conf.BROKER_PORT = 1978
+        self.app.conf.BROKER_VHOST = "foo"
+        self.assertDictContainsSubset({"port": ":1978",
+                                       "virtual_host": "/foo"},
+                                      self.app.amqp.get_broker_info())
+        conn = self.app.broker_connection(virtual_host="/value")
+        self.assertDictContainsSubset({"virtual_host": "/value"},
+                                      self.app.amqp.get_broker_info(conn))
+
+    def test_send_task_sent_event(self):
+        from celery.app import amqp
+
+        class Dispatcher(object):
+            sent = []
+
+            def send(self, type, **fields):
+                self.sent.append((type, fields))
+
+        conn = self.app.broker_connection()
+        chan = conn.channel()
+        try:
+            for e in ("foo_exchange", "moo_exchange", "bar_exchange"):
+                chan.exchange_declare(e, "direct", durable=True)
+                chan.queue_declare(e, durable=True)
+                chan.queue_bind(e, e, e)
+        finally:
+            chan.close()
+        assert conn.transport_cls == "memory"
+
+        pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
+        self.assertIn("foo_exchange", amqp._exchanges_declared)
+
+        dispatcher = Dispatcher()
+        self.assertTrue(pub.delay_task("footask", (), {},
+                                       exchange="moo_exchange",
+                                       routing_key="moo_exchange",
+                                       event_dispatcher=dispatcher))
+        self.assertTrue(dispatcher.sent)
+        self.assertEqual(dispatcher.sent[0][0], "task-sent")
+        self.assertTrue(pub.delay_task("footask", (), {},
+                                       event_dispatcher=dispatcher,
+                                       exchange="bar_exchange",
+                                       routing_key="bar_exchange"))
+        self.assertIn("bar_exchange", amqp._exchanges_declared)
+
+
+
+class test_BaseApp(unittest.TestCase):
+
+    def test_on_init(self):
+        app = BaseApp()
+
+
+class test_defaults(unittest.TestCase):
+
+    def test_str_to_bool(self):
+        for s in ("false", "no", "0"):
+            self.assertFalse(defaults.str_to_bool(s))
+        for s in ("true", "yes", "1"):
+            self.assertTrue(defaults.str_to_bool(s))
+        self.assertRaises(TypeError, defaults.str_to_bool, "unsure")

+ 10 - 5
celery/tests/test_task.py

@@ -240,7 +240,7 @@ class TestCeleryTasks(unittest.TestCase):
         self.assertEqual(task.ping(), 'pong')
 
     def assertNextTaskDataEqual(self, consumer, presult, task_name,
-            test_eta=False, **kwargs):
+            test_eta=False, test_expires=False, **kwargs):
         next_task = consumer.fetch()
         task_data = next_task.decode()
         self.assertEqual(task_data["id"], presult.task_id)
@@ -250,6 +250,10 @@ class TestCeleryTasks(unittest.TestCase):
             self.assertIsInstance(task_data.get("eta"), basestring)
             to_datetime = parse_iso8601(task_data.get("eta"))
             self.assertIsInstance(to_datetime, datetime)
+        if test_expires:
+            self.assertIsInstance(task_data.get("expires"), basestring)
+            to_datetime = parse_iso8601(task_data.get("expires"))
+            self.assertIsInstance(to_datetime, datetime)
         for arg_name, arg_value in kwargs.items():
             self.assertEqual(task_kwargs.get(arg_name), arg_value)
 
@@ -303,15 +307,16 @@ class TestCeleryTasks(unittest.TestCase):
 
         # With eta.
         presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
-                                  eta=datetime.now() + timedelta(days=1))
+                                  eta=datetime.now() + timedelta(days=1),
+                                  expires=datetime.now() + timedelta(days=2))
         self.assertNextTaskDataEqual(consumer, presult2, t1.name,
-                name="George Costanza", test_eta=True)
+                name="George Costanza", test_eta=True, test_expires=True)
 
         # With countdown.
         presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
-                                  countdown=10)
+                                  countdown=10, expires=12)
         self.assertNextTaskDataEqual(consumer, presult2, t1.name,
-                name="George Costanza", test_eta=True)
+                name="George Costanza", test_eta=True, test_expires=True)
 
         # Discarding all tasks.
         consumer.discard_all()

+ 9 - 1
setup.cfg

@@ -7,14 +7,19 @@ cover3-exclude = celery
                  celery.conf
                  celery.tests.*
                  celery.bin.celeryev
+                 celery.bin.celeryd_multi
+                 celery.bin.celeryd_detach
+                 celery.bin.celeryctl
+                 celery.bin.camqadm
                  celery.task
-                 celery.platform
+                 celery.platforms
                  celery.utils.patch
                  celery.utils.compat
                  celery.utils.mail
                  celery.utils.functional
                  celery.utils.dispatch*
                  celery.db.a805d4bd
+                 celery.db.dfd042c7
                  celery.contrib*
                  celery.concurrency.threads
                  celery.concurrency.processes.pool
@@ -22,6 +27,9 @@ cover3-exclude = celery
                  celery.backends.tyrant
                  celery.backends.pyredis
                  celery.backends.amqp
+                 celery.backends.cassandra
+                 celery.events.dumper
+                 celery.events.cursesmon
 
 [build_sphinx]
 source-dir = docs/