Browse Source

100% Coverage for celery.loaders.*

Ask Solem 14 years ago
parent
commit
02953e12d2
3 changed files with 166 additions and 15 deletions
  1. 3 2
      celery/loaders/app.py
  2. 12 6
      celery/loaders/base.py
  3. 151 7
      celery/tests/test_loaders.py

+ 3 - 2
celery/loaders/app.py

@@ -3,6 +3,7 @@ import os
 from celery.datastructures import DictAttribute
 from celery.exceptions import ImproperlyConfigured
 from celery.loaders.base import BaseLoader
+from celery.utils import get_cls_by_name
 
 
 ERROR_ENVVAR_NOT_SET = (
@@ -29,8 +30,8 @@ class AppLoader(BaseLoader):
     def config_from_object(self, obj, silent=False):
         if isinstance(obj, basestring):
             try:
-                obj = self.import_from_cwd(obj)
-            except ImportError:
+                obj = get_cls_by_name(obj, imp=self.import_from_cwd)
+            except (ImportError, AttributeError):
                 if silent:
                     return False
                 raise

+ 12 - 6
celery/loaders/base.py

@@ -113,19 +113,20 @@ class BaseLoader(object):
 
         return dict(map(getarg, args))
 
-    def mail_admins(self, subject, message, fail_silently=False,
+    def mail_admins(self, subject, body, fail_silently=False,
             sender=None, to=None, host=None, port=None,
             user=None, password=None, timeout=None):
-        from celery.utils import mail
         try:
-            message = mail.Message(sender=sender, to=to,
-                                    subject=subject, body=message)
-            mailer = mail.Mailer(host, port, user, password, timeout)
+            message = self.mail.Message(sender=sender, to=to,
+                                        subject=subject, body=body)
+            mailer = self.mail.Mailer(host=host, port=port,
+                                      user=user, password=password,
+                                      timeout=timeout)
             mailer.send(message)
         except Exception, exc:
             if not fail_silently:
                 raise
-            warnings.warn(mail.SendmailWarning(
+            warnings.warn(self.mail.SendmailWarning(
                 "Mail could not be sent: %r %r" % (
                     exc, {"To": to, "Subject": subject})))
 
@@ -133,3 +134,8 @@ class BaseLoader(object):
     def conf(self):
         """Loader configuration."""
         return self.read_configuration()
+
+    @cached_property
+    def mail(self):
+        from celery.utils import mail
+        return mail

+ 151 - 7
celery/tests/test_loaders.py

@@ -1,16 +1,62 @@
 import os
 import sys
-from celery.tests.utils import unittest
 
 from celery import task
 from celery import loaders
+from celery.app import app_or_default
+from celery.exceptions import ImproperlyConfigured
 from celery.loaders import base
 from celery.loaders import default
+from celery.loaders.app import AppLoader
 
 from celery.tests.compat import catch_warnings
+from celery.tests.utils import unittest
 from celery.tests.utils import with_environ, execute_context
 
 
+class ObjectConfig(object):
+    FOO = 1
+    BAR = 2
+
+object_config = ObjectConfig()
+dict_config = dict(FOO=10, BAR=20)
+
+
+class Object(object):
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+
+class MockMail(object):
+
+    class SendmailWarning(UserWarning):
+        pass
+
+    class Message(Object):
+        pass
+
+    class Mailer(Object):
+        sent = []
+        raise_on_send = False
+
+        def send(self, message):
+            if self.__class__.raise_on_send:
+                raise KeyError("foo")
+            self.sent.append(message)
+
+
+class DummyLoader(base.BaseLoader):
+
+    def read_configuration(self):
+        return {"foo": "bar", "CELERY_IMPORTS": ("os", "sys")}
+
+    @property
+    def mail(self):
+        return MockMail()
+
+
 class TestLoaders(unittest.TestCase):
 
     def test_get_loader_cls(self):
@@ -18,18 +64,36 @@ class TestLoaders(unittest.TestCase):
         self.assertEqual(loaders.get_loader_cls("default"),
                           default.Loader)
 
+    def test_current_loader(self):
+        loader1 = loaders.current_loader()
+        loader2 = loaders.current_loader()
+        self.assertIs(loader1, loader2)
+        self.assertIs(loader2, loaders._loader)
+
+    def test_load_settings(self):
+        loader = loaders.current_loader()
+        loaders._settings = None
+        settings = loaders.load_settings()
+        self.assertTrue(loaders._settings)
+        settings = loaders.load_settings()
+        self.assertIs(settings, loaders._settings)
+        self.assertIs(settings, loader.conf)
+
     @with_environ("CELERY_LOADER", "default")
     def test_detect_loader_CELERY_LOADER(self):
         self.assertIsInstance(loaders.setup_loader(), default.Loader)
 
 
-class DummyLoader(base.BaseLoader):
-
-    def read_configuration(self):
-        return {"foo": "bar", "CELERY_IMPORTS": ("os", "sys")}
-
-
 class TestLoaderBase(unittest.TestCase):
+    message_options = {"subject": "Subject",
+                       "body": "Body",
+                       "sender": "x@x.com",
+                       "to": "y@x.com"}
+    server_options = {"host": "smtp.x.com",
+                      "port": 1234,
+                      "user": "x",
+                      "password": "qwerty",
+                      "timeout": 3}
 
     def setUp(self):
         self.loader = DummyLoader()
@@ -50,6 +114,50 @@ class TestLoaderBase(unittest.TestCase):
         self.assertItemsEqual(self.loader.import_default_modules(),
                               [os, sys, task])
 
+    def test_import_from_cwd_custom_imp(self):
+
+        def imp(module):
+            imp.called = True
+        imp.called = False
+
+        self.loader.import_from_cwd("foo", imp=imp)
+        self.assertTrue(imp.called)
+
+    def test_mail_admins_errors(self):
+        MockMail.Mailer.raise_on_send = True
+        opts = dict(self.message_options, **self.server_options)
+
+        def with_catch_warnings(log):
+            self.loader.mail_admins(fail_silently=True, **opts)
+            return log[0].message
+
+        warning = execute_context(catch_warnings(record=True),
+                                  with_catch_warnings)
+        self.assertIsInstance(warning, MockMail.SendmailWarning)
+        self.assertIn("KeyError", warning.args[0])
+
+        self.assertRaises(KeyError, self.loader.mail_admins,
+                          fail_silently=False, **opts)
+
+    def test_mail_admins(self):
+        MockMail.Mailer.raise_on_send = False
+        opts = dict(self.message_options, **self.server_options)
+
+        self.loader.mail_admins(**opts)
+        message = MockMail.Mailer.sent.pop()
+        self.assertDictContainsSubset(vars(message), self.message_options)
+
+    def test_mail_attribute(self):
+        from celery.utils import mail
+        loader = base.BaseLoader()
+        self.assertIs(loader.mail, mail)
+
+    def test_cmdline_config_ValueError(self):
+        self.assertRaises(ValueError, self.loader.cmdline_config_parser,
+                         ["broker.port=foobar"])
+
+
+
 
 class TestDefaultLoader(unittest.TestCase):
 
@@ -115,3 +223,39 @@ class TestDefaultLoader(unittest.TestCase):
         context = catch_warnings(record=True)
         execute_context(context, with_catch_warnings)
         self.assertTrue(context_executed[0])
+
+
+class test_AppLoader(unittest.TestCase):
+
+    def setUp(self):
+        self.app = app_or_default()
+        self.loader = AppLoader(app=self.app)
+
+    def test_config_from_envvar(self, key="CELERY_HARNESS_CFG1"):
+        self.assertFalse(self.loader.config_from_envvar("HDSAJIHWIQHEWQU",
+                                                        silent=True))
+        self.assertRaises(ImproperlyConfigured,
+                          self.loader.config_from_envvar, "HDSAJIHWIQHEWQU",
+                          silent=False)
+        os.environ[key] = __name__ + ".object_config"
+        self.assertTrue(self.loader.config_from_envvar(key))
+        self.assertEqual(self.loader.conf["FOO"], 1)
+        self.assertEqual(self.loader.conf["BAR"], 2)
+
+        os.environ[key] = "unknown_asdwqe.asdwqewqe"
+        self.assertRaises(ImportError,
+                          self.loader.config_from_envvar, key, silent=False)
+        self.assertFalse(self.loader.config_from_envvar(key, silent=True))
+
+        os.environ[key] = __name__ + ".dict_config"
+        self.assertTrue(self.loader.config_from_envvar(key))
+        self.assertEqual(self.loader.conf["FOO"], 10)
+        self.assertEqual(self.loader.conf["BAR"], 20)
+
+    def test_on_worker_init(self):
+        self.loader.conf["CELERY_IMPORTS"] = ("subprocess", )
+        sys.modules.pop("subprocess", None)
+        self.loader.on_worker_init()
+        self.assertIn("subprocess", sys.modules)
+
+