Parcourir la source

100% coverage for celery.app + celery.app.base

Ask Solem il y a 14 ans
Parent
commit
6fa5db52a0
1 fichiers modifiés avec 132 ajouts et 2 suppressions
  1. 132 2
      celery/tests/test_app/test_app.py

+ 132 - 2
celery/tests/test_app/test_app.py

@@ -1,13 +1,21 @@
+from __future__ import with_statement
+
 import os
+import sys
+
+from contextlib import contextmanager
+
+from mock import Mock
 
 from celery import Celery
+from celery import app as _app
 from celery.app import defaults
-from celery.app.base import BaseApp
+from celery.app.base import BaseApp, pyimplementation
 from celery.loaders.base import BaseLoader
 from celery.utils.serialization import pickle
 
 from celery.tests import config
-from celery.tests.utils import unittest
+from celery.tests.utils import unittest, mask_modules
 
 THIS_IS_A_KEY = "this is a value"
 
@@ -43,6 +51,9 @@ class test_App(unittest.TestCase):
         task = app.task(fun)
         self.assertEqual(task.name, app.main + ".fun")
 
+    def test_repr(self):
+        self.assertTrue(repr(self.app))
+
     def test_TaskSet(self):
         ts = self.app.TaskSet()
         self.assertListEqual(ts.tasks, [])
@@ -167,6 +178,35 @@ class test_App(unittest.TestCase):
         self.assertDictContainsSubset({"virtual_host": "/value"},
                                       conn.info())
 
+    def test_BROKER_BACKEND_alias(self):
+        self.assertEqual(self.app.conf.BROKER_BACKEND,
+                         self.app.conf.BROKER_TRANSPORT)
+
+    def test_with_default_connection(self):
+
+        @self.app.with_default_connection
+        def handler(connection=None, foo=None):
+            return connection, foo
+
+        connection, foo = handler(foo=42)
+        self.assertEqual(foo, 42)
+        self.assertTrue(connection)
+
+    def test_after_fork(self):
+        p = self.app._pool = Mock()
+        self.app._after_fork(self.app)
+        p.force_close_all.assert_called_with()
+        self.assertIsNone(self.app._pool)
+        self.app._after_fork(self.app)
+
+    def test_pool_no_multiprocessing(self):
+        with mask_modules("multiprocessing.util"):
+            pool = self.app.pool
+            self.assertIs(pool, self.app._pool)
+
+    def test_bugreport(self):
+        self.assertTrue(self.app.bugreport())
+
     def test_send_task_sent_event(self):
         from celery.app import amqp
 
@@ -218,3 +258,93 @@ class test_defaults(unittest.TestCase):
         for s in ("true", "yes", "1"):
             self.assertTrue(defaults.str_to_bool(s))
         self.assertRaises(TypeError, defaults.str_to_bool, "unsure")
+
+
+
+class test_debugging_utils(unittest.TestCase):
+
+    def test_enable_disable_trace(self):
+        try:
+            _app.enable_trace()
+            self.assertEqual(_app.app_or_default, _app._app_or_default_trace)
+            _app.disable_trace()
+            self.assertEqual(_app.app_or_default, _app._app_or_default)
+        finally:
+            _app.disable_trace()
+
+
+class test_compilation(unittest.TestCase):
+    _clean = ("celery.app.base", )
+
+    def setUp(self):
+        self._prev = dict((k, sys.modules.pop(k, None)) for k in self._clean)
+
+    def tearDown(self):
+        sys.modules.update(self._prev)
+
+    def test_kombu_version_check(self):
+        import kombu
+        kombu.VERSION = (0, 9, 9)
+        with self.assertRaises(ImportError):
+            __import__("celery.app.base")
+
+
+class test_pyimplementation(unittest.TestCase):
+
+    @contextmanager
+    def platform_pyimp(self, replace=None):
+        import platform
+        prev = getattr(platform, "python_implementation", None)
+        if replace:
+            platform.python_implementation = replace
+        else:
+            try:
+                delattr(platform, "python_implementation")
+            except AttributeError:
+                pass
+        yield
+        if prev is not None:
+            platform.python_implementation = prev
+
+    @contextmanager
+    def sys_platform(self, value):
+        prev, sys.platform = sys.platform, value
+        yield
+        sys.platform = prev
+
+    @contextmanager
+    def pypy_version(self, value=None):
+        prev = getattr(sys, "pypy_version_info", None)
+        if value:
+            sys.pypy_version_info = value
+        else:
+            try:
+                delattr(sys, "pypy_version_info")
+            except AttributeError:
+                pass
+        yield
+        if prev is not None:
+            sys.pypy_version_info = prev
+
+    def test_platform_python_implementation(self):
+        with self.platform_pyimp(lambda: "Xython"):
+            self.assertEqual(pyimplementation(), "Xython")
+
+    def test_platform_jython(self):
+        with self.platform_pyimp():
+            with self.sys_platform("java 1.6.51"):
+                self.assertIn("Jython", pyimplementation())
+
+    def test_platform_pypy(self):
+        with self.platform_pyimp():
+            with self.sys_platform("darwin"):
+                with self.pypy_version((1, 4, 3)):
+                    self.assertIn("PyPy", pyimplementation())
+                with self.pypy_version((1, 4, 3, "a4")):
+                    self.assertIn("PyPy", pyimplementation())
+
+    def test_platform_fallback(self):
+        with self.platform_pyimp():
+            with self.sys_platform("darwin"):
+                with self.pypy_version():
+                    self.assertEqual("CPython", pyimplementation())