Browse Source

More tests

Ask Solem 13 years ago
parent
commit
f51114dd9a

+ 15 - 3
celery/__compat__.py

@@ -15,7 +15,7 @@ DEFAULT_ATTRS = set(["__file__", "__path__", "__doc__", "__all__"])
 
 # im_func is no longer available in Py3.
 # instead the unbound method itself can be used.
-if sys.version_info[0] == 3:
+if sys.version_info[0] == 3:  # pragma: no cover
     def fun_of_method(method):
         return method
 else:
@@ -69,6 +69,17 @@ COMPAT_MODULES = {
             "tasks": "tasks",
         },
     },
+    "celery.task": {
+        "control": {
+            "broadcast": "control.broadcast",
+            "rate_limit": "control.rate_limit",
+            "time_limit": "control.time_limit",
+            "ping": "control.ping",
+            "revoke": "control.revoke",
+            "discard_all": "control.discard_all",
+            "inspect": "control.inspect",
+        }
+    }
 }
 
 
@@ -158,8 +169,9 @@ def get_compat_module(pkg, name):
             return Proxy(getappattr, (attr.split('.'), ))
         return attr
 
-    return create_module(name, COMPAT_MODULES[pkg.__name__][name],
-                         pkg=pkg, prepare_attr=prepare)
+    attrs = dict(COMPAT_MODULES[pkg.__name__][name])
+    attrs["__all__"] = attrs.keys()
+    return create_module(name, attrs, pkg=pkg, prepare_attr=prepare)
 
 
 def get_origins(defs):

+ 0 - 1
celery/__init__.py

@@ -17,7 +17,6 @@ __docformat__ = "restructuredtext"
 # Lazy loading
 from .__compat__ import recreate_module
 
-
 old_module, new_module = recreate_module(__name__,
     by_module={
         "celery.app":       ["Celery", "bugreport"],

+ 1 - 1
celery/app/__init__.py

@@ -43,7 +43,7 @@ set_default_app(Celery("default", loader=default_loader,
 
 
 def bugreport():
-    return current_app.bugreport()
+    return current_app().bugreport()
 
 
 def _app_or_default(app=None):

+ 4 - 2
celery/app/amqp.py

@@ -16,7 +16,7 @@ from datetime import timedelta
 from kombu import BrokerConnection, Exchange
 from kombu import compat as messaging
 from kombu import pools
-from kombu.common import maybe_declare
+from kombu.common import declaration_cached, maybe_declare
 
 from celery import signals
 from celery.utils import cached_property, lpmerge, uuid
@@ -35,6 +35,7 @@ QUEUE_FORMAT = """
 binding:%(binding_key)s
 """
 
+
 def extract_msg_options(options, keep=MSG_OPTIONS):
     """Extracts known options to `basic_publish` from a dict,
     and returns a new dict."""
@@ -157,7 +158,8 @@ class TaskPublisher(messaging.Publisher):
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
     def declare(self):
-        if self.exchange.name and not declaration_cached(self.exchange):
+        if self.exchange.name and \
+                not declaration_cached(self.exchange, self.channel):
             super(TaskPublisher, self).declare()
 
     def _get_queue(self, name):

+ 8 - 6
celery/app/base.py

@@ -19,13 +19,14 @@ from contextlib import contextmanager
 from copy import deepcopy
 from functools import wraps
 
+from billiard.util import register_after_fork
 from kombu.clocks import LamportClock
+from kombu.utils import cached_property
 
 from celery import platforms
 from celery.exceptions import AlwaysEagerIgnored
 from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
-from celery.utils import cached_property, register_after_fork
 from celery.utils.functional import first
 from celery.utils.imports import instantiate, symbol_by_name
 
@@ -159,7 +160,7 @@ class Celery(object):
             eta=None, task_id=None, publisher=None, connection=None,
             connect_timeout=None, result_cls=None, expires=None,
             queues=None, **options):
-        if self.conf.CELERY_ALWAYS_EAGER:
+        if self.conf.CELERY_ALWAYS_EAGER:  # pragma: no cover
             warnings.warn(AlwaysEagerIgnored(
                 "CELERY_ALWAYS_EAGER has no effect on send_task"))
 
@@ -234,9 +235,6 @@ class Celery(object):
 
     def prepare_config(self, c):
         """Prepare configuration before it is merged with the defaults."""
-        if self._preconf:
-            for key, value in self._preconf.iteritems():
-                setattr(c, key, value)
         return find_deprecated_settings(c)
 
     def now(self):
@@ -275,8 +273,12 @@ class Celery(object):
         return backend(app=self, url=url)
 
     def _get_config(self):
-        return Settings({}, [self.prepare_config(self.loader.conf),
+        s = Settings({}, [self.prepare_config(self.loader.conf),
                              deepcopy(DEFAULTS)])
+        if self._preconf:
+            for key, value in self._preconf.iteritems():
+                setattr(s, key, value)
+        return s
 
     def _after_fork(self, obj_):
         if self._pool:

+ 18 - 11
celery/app/builtins.py

@@ -39,7 +39,11 @@ def add_backend_cleanup_task(app):
     may even clean up in realtime so that a periodic cleanup is not necessary.
 
     """
-    return app.task(name="celery.backend_cleanup")(app.backend.cleanup)
+
+    @app.task(name="celery.backend_cleanup")
+    def backend_cleanup():
+        app.backend.cleanup()
+    return backend_cleanup
 
 
 @builtin_task
@@ -79,17 +83,19 @@ def add_group_task(app):
         def run(self, tasks, result):
             app = self.app
             result = from_serializable(result)
+            if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
+                return app.TaskSetResult(result.id,
+                        [subtask(task).apply(taskset_id=self.request.taskset)
+                            for task in tasks])
             with app.pool.acquire(block=True) as conn:
                 with app.amqp.TaskPublisher(conn) as publisher:
-                    res_ = [subtask(task).apply_async(
-                                        taskset_id=self.request.taskset,
-                                        publisher=publisher)
-                                for task in tasks]
+                    [subtask(task).apply_async(
+                                    taskset_id=self.request.taskset,
+                                    publisher=publisher)
+                            for task in tasks]
             parent = get_current_task()
             if parent:
                 parent.request.children.append(result)
-            if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
-                return app.TaskSetResult(result.id, res_)
             return result
 
         def prepare(self, options, tasks, **kwargs):
@@ -111,8 +117,7 @@ def add_group_task(app):
 
         def apply(self, args=(), kwargs={}, **options):
             tasks, result = self.prepare(options, **kwargs)
-            return super(Group, self).apply((tasks, result), {"eager": True},
-                                            **options)
+            return super(Group, self).apply((tasks, result), **options)
 
     return Group
 
@@ -177,8 +182,10 @@ def add_chord_task(app):
             body = maybe_subtask(kwargs["body"])
 
             callback_id = body.options.setdefault("task_id", task_id or uuid())
-            super(Chord, self).apply_async(args, kwargs, **options)
-            return self.AsyncResult(callback_id)
+            parent = super(Chord, self).apply_async(args, kwargs, **options)
+            body_result = self.AsyncResult(callback_id)
+            body_result.parent = parent
+            return body_result
 
         def apply(self, args=(), kwargs={}, **options):
             body = kwargs["body"]

+ 1 - 0
celery/apps/worker.py

@@ -279,6 +279,7 @@ install_worker_term_hard_handler = partial(
     _shutdown_handler, sig="SIGQUIT", how="terminate", exc=SystemTerminate,
 )
 
+
 def on_SIGINT(worker):
     print("celeryd: Hitting Ctrl+C again will terminate all running tasks!")
     install_worker_term_hard_handler(worker, sig="SIGINT")

+ 2 - 1
celery/beat.py

@@ -448,7 +448,8 @@ class _Threaded(threading.Thread):
 
 supports_fork = True
 try:
-    import _multiprocessing
+    from billiard._ext import _billiard
+    supports_fork = True if _billiard else False
 except ImportError:
     supports_fork = False
 

+ 10 - 7
celery/canvas.py

@@ -162,10 +162,12 @@ class Signature(dict):
 
     def __or__(self, other):
         if isinstance(other, chain):
-            return chain(self.tasks + other.tasks)
+            return chain(*self.tasks + other.tasks)
         elif isinstance(other, Signature):
+            if isinstance(self, chain):
+                return chain(*self.tasks + (other, ))
             return chain(self, other)
-        return NotImplementedError
+        return NotImplemented
 
     def __invert__(self):
         return self.apply_async().get()
@@ -241,7 +243,8 @@ class chord(Signature):
     @classmethod
     def from_dict(self, d):
         kwargs = d["kwargs"]
-        return chord(kwargs["header"], kwargs["body"], **kwdict(d["options"]))
+        return chord(kwargs["header"], kwargs.get("body"),
+                     **kwdict(d["options"]))
 
     def __call__(self, body=None, **options):
         _chord = self.Chord
@@ -254,10 +257,10 @@ class chord(Signature):
 
     def clone(self, *args, **kwargs):
         s = Signature.clone(self, *args, **kwargs)
-        # need make copy of body
+        # need to make copy of body
         try:
-            kwargs["body"] = kwargs["body"].clone()
-        except KeyError:
+            s.kwargs["body"] = s.kwargs["body"].clone()
+        except (AttributeError, KeyError):
             pass
         return s
 
@@ -280,7 +283,7 @@ class chord(Signature):
 
     @property
     def body(self):
-        return self.kwargs["body"]
+        return self.kwargs.get("body")
 Signature.register_type(chord)
 
 

+ 0 - 2
celery/concurrency/eventlet.py

@@ -8,8 +8,6 @@ if not os.environ.get("EVENTLET_NOPATCH"):
     eventlet.monkey_patch()
     eventlet.debug.hub_prevent_multiple_readers(False)
 
-import sys
-
 from time import time
 
 from celery import signals

+ 0 - 2
celery/concurrency/gevent.py

@@ -6,8 +6,6 @@ if not os.environ.get("GEVENT_NOPATCH"):
     from gevent import monkey
     monkey.patch_all()
 
-import sys
-
 from time import time
 
 from celery.utils import timer2

+ 1 - 0
celery/loaders/__init__.py

@@ -19,6 +19,7 @@ LOADER_ALIASES = {"app": "celery.loaders.app:AppLoader",
                   "default": "celery.loaders.default:Loader",
                   "django": "djcelery.loaders:DjangoLoader"}
 
+
 def get_loader_cls(loader):
     """Get loader class by name/alias"""
     return symbol_by_name(loader, LOADER_ALIASES)

+ 10 - 9
celery/local.py

@@ -34,12 +34,13 @@ class Proxy(object):
         object.__setattr__(self, '_Proxy__local', local)
         object.__setattr__(self, '_Proxy__args', args or ())
         object.__setattr__(self, '_Proxy__kwargs', kwargs or {})
-        object.__setattr__(self, '__custom_name__', name)
+        if name is not None:
+            object.__setattr__(self, '__custom_name__', name)
 
     @property
     def __name__(self):
         try:
-            return object.__getattr__(self, "__custom_name__")
+            return self.__custom_name__
         except AttributeError:
             return self._get_current_object().__name__
 
@@ -67,32 +68,32 @@ class Proxy(object):
     def __dict__(self):
         try:
             return self._get_current_object().__dict__
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             raise AttributeError('__dict__')
 
     def __repr__(self):
         try:
             obj = self._get_current_object()
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return '<%s unbound>' % self.__class__.__name__
         return repr(obj)
 
     def __nonzero__(self):
         try:
             return bool(self._get_current_object())
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return False
 
     def __unicode__(self):
         try:
             return unicode(self._get_current_object())
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return repr(self)
 
     def __dir__(self):
         try:
             return dir(self._get_current_object())
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return []
 
     def __getattr__(self, name):
@@ -155,8 +156,8 @@ class Proxy(object):
     __hex__ = lambda x: hex(x._get_current_object())
     __index__ = lambda x: x._get_current_object().__index__()
     __coerce__ = lambda x, o: x.__coerce__(x, o)
-    __enter__ = lambda x: x.__enter__()
-    __exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw)
+    __enter__ = lambda x: x._get_current_object().__enter__()
+    __exit__ = lambda x, *a, **kw: x._get_current_object().__exit__(*a, **kw)
     __reduce__ = lambda x: x._get_current_object().__reduce__()
 
 

+ 1 - 0
celery/security/serialization.py

@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 import base64
 

+ 2 - 2
celery/security/utils.py

@@ -8,8 +8,8 @@ from celery.exceptions import SecurityError
 
 try:
     from OpenSSL import crypto
-except ImportError:
-    crypto = None  # noqa
+except ImportError:  # pragma: no cover
+    crypto = None    # noqa
 
 
 @contextmanager

+ 0 - 13
celery/task/control.py

@@ -1,13 +0,0 @@
-from __future__ import absolute_import
-
-from celery import current_app
-from celery.local import Proxy
-
-
-broadcast = Proxy(lambda: current_app.control.broadcast)
-rate_limit = Proxy(lambda: current_app.control.rate_limit)
-time_limit = Proxy(lambda: current_app.control.time_limit)
-ping = Proxy(lambda: current_app.control.ping)
-revoke = Proxy(lambda: current_app.control.revoke)
-discard_all = Proxy(lambda: current_app.control.discard_all)
-inspect = Proxy(lambda: current_app.control.inspect)

+ 32 - 2
celery/tests/test_app/__init__.py

@@ -3,11 +3,13 @@ from __future__ import with_statement
 
 import os
 
-from mock import Mock
+from mock import Mock, patch
+from pickle import loads, dumps
 
 from celery import Celery
 from celery import app as _app
 from celery.app import defaults
+from celery.app import state
 from celery.loaders.base import BaseLoader
 from celery.platforms import pyimplementation
 from celery.utils.serialization import pickle
@@ -36,6 +38,15 @@ def _get_test_config():
 test_config = _get_test_config()
 
 
+class test_module(Case):
+
+    def test_default_app(self):
+        self.assertEqual(_app.default_app, state.default_app)
+
+    def test_bugreport(self):
+        self.assertTrue(_app.bugreport())
+
+
 class test_App(Case):
 
     def setUp(self):
@@ -52,6 +63,10 @@ class test_App(Case):
         task = app.task(fun)
         self.assertEqual(task.name, app.main + ".fun")
 
+    def test_with_broker(self):
+        app = Celery(set_as_current=False, broker="foo://baribaz")
+        self.assertEqual(app.conf.BROKER_HOST, "foo://baribaz")
+
     def test_repr(self):
         self.assertTrue(repr(self.app))
 
@@ -148,6 +163,22 @@ class test_App(Case):
         self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
         self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
 
+    def test_WorkController(self):
+        x = self.app.Worker()
+        self.assertIs(x.app, self.app)
+
+    def test_AsyncResult(self):
+        x = self.app.AsyncResult("1")
+        self.assertIs(x.app, self.app)
+        r = loads(dumps(x))
+        # not set as current, so ends up as default app after reduce
+        self.assertIs(r.app, state.default_app)
+
+    @patch("celery.bin.celery.CeleryCommand.execute_from_commandline")
+    def test_start(self, execute):
+        self.app.start()
+        self.assertTrue(execute.called)
+
     def test_mail_admins(self):
 
         class Loader(BaseLoader):
@@ -208,7 +239,6 @@ class test_App(Case):
         self.assertTrue(self.app.bugreport())
 
     def test_send_task_sent_event(self):
-        from celery.app import amqp
 
         class Dispatcher(object):
             sent = []

+ 1 - 1
celery/tests/test_app/test_beat.py

@@ -306,7 +306,7 @@ class test_EmbeddedService(Case):
 
     def test_start_stop_process(self):
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
             raise SkipTest("multiprocessing not available")
 

+ 113 - 0
celery/tests/test_app/test_builtins.py

@@ -0,0 +1,113 @@
+from __future__ import absolute_import
+
+from mock import Mock
+
+from celery import current_app as app, group, task, chord
+from celery.app import builtins
+from celery.app.state import _tls
+from celery.tests.utils import Case
+
+
+@task
+def add(x, y):
+    return x + y
+
+
+@task
+def xsum(x):
+    return sum(x)
+
+
+class test_backend_cleanup(Case):
+
+    def test_run(self):
+        prev = app.backend
+        app.backend.cleanup = Mock()
+        app.backend.cleanup.__name__ = "cleanup"
+        try:
+            cleanup_task = builtins.add_backend_cleanup_task(app)
+            cleanup_task()
+            self.assertTrue(app.backend.cleanup.called)
+        finally:
+            app.backend = prev
+
+
+class test_group(Case):
+
+    def setUp(self):
+        self.prev = app.tasks.get("celery.group")
+        self.task = builtins.add_group_task(app)()
+
+    def tearDown(self):
+        app.tasks["celery.group"] = self.prev
+
+    def test_apply_async_eager(self):
+        self.task.apply = Mock()
+        app.conf.CELERY_ALWAYS_EAGER = True
+        try:
+            self.task.apply_async()
+        finally:
+            app.conf.CELERY_ALWAYS_EAGER = False
+        self.assertTrue(self.task.apply.called)
+
+    def test_apply(self):
+        x = group([add.s(4, 4), add.s(8, 8)])
+        x.name = self.task.name
+        res = x.apply()
+        self.assertEqual(res.get().join(), [8, 16])
+
+    def test_apply_async(self):
+        x = group([add.s(4, 4), add.s(8, 8)])
+        x.apply_async()
+
+    def test_apply_async_with_parent(self):
+        _tls.current_task = add
+        try:
+            x = group([add.s(4, 4), add.s(8, 8)])
+            x.apply_async()
+            self.assertTrue(add.request.children)
+        finally:
+            _tls.current_task = None
+
+
+class test_chain(Case):
+
+    def setUp(self):
+        self.prev = app.tasks.get("celery.chain")
+        self.task = builtins.add_chain_task(app)()
+
+    def tearDown(self):
+        app.tasks["celery.chain"] = self.prev
+
+    def test_apply_async(self):
+        c = add.s(2, 2) | add.s(4) | add.s(8)
+        result = c.apply_async()
+        self.assertTrue(result.parent)
+        self.assertTrue(result.parent.parent)
+        self.assertIsNone(result.parent.parent.parent)
+
+
+class test_chord(Case):
+
+    def setUp(self):
+        self.prev = app.tasks.get("celery.chord")
+        self.task = builtins.add_chain_task(app)()
+
+    def tearDown(self):
+        app.tasks["celery.chord"] = self.prev
+
+    def test_apply_async(self):
+        x = chord([add.s(i, i) for i in xrange(10)], body=xsum.s())
+        r = x.apply_async()
+        self.assertTrue(r)
+        self.assertTrue(r.parent)
+
+    def test_apply_eager(self):
+        app.conf.CELERY_ALWAYS_EAGER = True
+        try:
+            x = chord([add.s(i, i) for i in xrange(10)], body=xsum.s())
+            r = x.apply_async()
+            self.assertEqual(r.get(), 90)
+
+        finally:
+            app.conf.CELERY_ALWAYS_EAGER = False

+ 28 - 0
celery/tests/test_app/test_loaders.py

@@ -4,6 +4,8 @@ from __future__ import with_statement
 import os
 import sys
 
+from mock import patch
+
 from celery import loaders
 from celery.app import app_or_default
 from celery.exceptions import (
@@ -13,6 +15,7 @@ from celery.exceptions import (
 from celery.loaders import base
 from celery.loaders import default
 from celery.loaders.app import AppLoader
+from celery.utils.imports import NotAPackage
 
 from celery.tests.utils import AppCase, Case
 from celery.tests.compat import catch_warnings
@@ -158,6 +161,31 @@ class TestDefaultLoader(Case):
         self.assertFalse(l.wanted_module_item("__FOO"))
         self.assertFalse(l.wanted_module_item("foo"))
 
+    @patch("celery.loaders.default.find_module")
+    def test_read_configuration_not_a_package(self, find_module):
+        find_module.side_effect = NotAPackage()
+        l = default.Loader()
+        with self.assertRaises(NotAPackage):
+            l.read_configuration()
+
+    @patch("celery.loaders.default.find_module")
+    def test_read_configuration_py_in_name(self, find_module):
+        prev = os.environ["CELERY_CONFIG_MODULE"]
+        os.environ["CELERY_CONFIG_MODULE"] = "celeryconfig.py"
+        try:
+            find_module.side_effect = NotAPackage()
+            l = default.Loader()
+            with self.assertRaises(NotAPackage):
+                l.read_configuration()
+        finally:
+            os.environ["CELERY_CONFIG_MODULE"] = prev
+
+    @patch("celery.loaders.default.find_module")
+    def test_read_configuration_importerror(self, find_module):
+        find_module.side_effect = ImportError()
+        l = default.Loader()
+        l.read_configuration()
+
     def test_read_configuration(self):
         from types import ModuleType
 

+ 73 - 2
celery/tests/test_app/test_log.py

@@ -5,11 +5,13 @@ import sys
 import logging
 from tempfile import mktemp
 
+from mock import patch, Mock
+
 from celery import current_app
 from celery.app.log import Logging
 from celery.utils.log import LoggingProxy
 from celery.utils import uuid
-from celery.utils.log import get_logger
+from celery.utils.log import get_logger, ColorFormatter, logger as base_logger
 from celery.tests.utils import (
     Case, override_stdouts, wrap_logger, get_handlers,
 )
@@ -17,6 +19,58 @@ from celery.tests.utils import (
 log = current_app.log
 
 
+class test_ColorFormatter(Case):
+
+    @patch("celery.utils.log.safe_str")
+    @patch("logging.Formatter.formatException")
+    def test_formatException_not_string(self, fe, safe_str):
+        x = ColorFormatter("HELLO")
+        value = KeyError()
+        fe.return_value = value
+        self.assertIs(x.formatException(value), value)
+        self.assertTrue(fe.called)
+        self.assertFalse(safe_str.called)
+
+    @patch("logging.Formatter.formatException")
+    @patch("celery.utils.log.safe_str")
+    def test_formatException_string(self, safe_str, fe, value="HELLO"):
+        x = ColorFormatter(value)
+        fe.return_value = value
+        self.assertTrue(x.formatException(value))
+        self.assertTrue(safe_str.called)
+
+    @patch("celery.utils.log.safe_str")
+    def test_format_raises(self, safe_str):
+        x = ColorFormatter("HELLO")
+
+        def on_safe_str(s):
+            try:
+                raise ValueError("foo")
+            finally:
+                safe_str.side_effect = None
+        safe_str.side_effect = on_safe_str
+
+        record = Mock()
+        record.levelname = "ERROR"
+        record.msg = "HELLO"
+        record.exc_text = "error text"
+        safe_str.return_value = record
+
+        x.format(record)
+        self.assertIn("<Unrepresentable", record.msg)
+        self.assertEqual(safe_str.call_count, 2)
+
+    @patch("celery.utils.log.safe_str")
+    def test_format_raises_no_color(self, safe_str):
+        x = ColorFormatter("HELLO", False)
+        record = Mock()
+        record.levelname = "ERROR"
+        record.msg = "HELLO"
+        record.exc_text = "error text"
+        x.format(record)
+        self.assertEqual(safe_str.call_count, 1)
+
+
 class test_default_logger(Case):
 
     def setUp(self):
@@ -24,6 +78,14 @@ class test_default_logger(Case):
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         Logging._setup = False
 
+    def test_get_logger_sets_parent(self):
+        logger = get_logger("celery.test_get_logger")
+        self.assertEqual(logger.parent.name, base_logger.name)
+
+    def test_get_logger_root(self):
+        logger = get_logger(base_logger.name)
+        self.assertIs(logger.parent, logging.root)
+
     def test_setup_logging_subsystem_colorize(self):
         log.setup_logging_subsystem(colorize=None)
         log.setup_logging_subsystem(colorize=True)
@@ -57,7 +119,6 @@ class test_default_logger(Case):
         Logging._setup = False
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False, colorize=None)
-        print(logger.handlers)
         self.assertIs(get_handlers(logger)[0].stream, sys.__stderr__,
                 "setup_logger logs to stderr without logfile argument.")
 
@@ -112,6 +173,16 @@ class test_default_logger(Case):
             self.assertFalse(p.isatty())
             self.assertIsNone(p.fileno())
 
+    def test_logging_proxy_recurse_protection(self):
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                root=False)
+        p = LoggingProxy(logger, loglevel=logging.ERROR)
+        p._thread.recurse_protection = True
+        try:
+            self.assertIsNone(p.write("FOOFO"))
+        finally:
+            p._thread.recurse_protection = False
+
 
 class test_task_logger(test_default_logger):
 

+ 5 - 10
celery/tests/test_bin/test_celeryd.py

@@ -19,15 +19,10 @@ from celery import current_app
 from celery.apps import worker as cd
 from celery.bin.celeryd import WorkerCommand, main as celeryd_main
 from celery.exceptions import ImproperlyConfigured, SystemTerminate
+from celery.utils.log import ensure_process_aware_logger
 
-from celery.tests.utils import (
-    AppCase, WhateverIO, mask_modules,
-    reset_modules, skip_unless_module,
-    create_pidlock,
-)
-
+from celery.tests.utils import AppCase, WhateverIO, create_pidlock
 
-from celery.utils.log import ensure_process_aware_logger
 ensure_process_aware_logger()
 
 
@@ -416,7 +411,7 @@ class test_signal_handlers(AppCase):
     @disable_stdouts
     def test_worker_int_handler_only_stop_MainProcess(self):
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
             raise SkipTest("only relevant for multiprocessing")
         process = current_process()
@@ -439,7 +434,7 @@ class test_signal_handlers(AppCase):
     @disable_stdouts
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
             raise SkipTest("only relevant for multiprocessing")
         process = current_process()
@@ -477,7 +472,7 @@ class test_signal_handlers(AppCase):
     @disable_stdouts
     def test_worker_term_handler_only_stop_MainProcess(self):
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
             raise SkipTest("only relevant for multiprocessing")
         process = current_process()

+ 158 - 0
celery/tests/test_task/test_canvas.py

@@ -0,0 +1,158 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from mock import Mock
+
+from celery import task
+from celery.canvas import Signature, chain, group, chord, subtask
+
+from celery.tests.utils import Case
+
+SIG = Signature({"task": "TASK",
+                 "args": ("A1", ),
+                 "kwargs": {"K1": "V1"},
+                 "options": {"task_id": "TASK_ID"},
+                 "subtask_type": ""})
+
+
+@task
+def add(x, y):
+    return x + y
+
+
+@task
+def mul(x, y):
+    return x * y
+
+
+@task
+def div(x, y):
+    return x / y
+
+
+class test_Signature(Case):
+
+    def test_getitem_property_class(self):
+        self.assertTrue(Signature.task)
+        self.assertTrue(Signature.args)
+        self.assertTrue(Signature.kwargs)
+        self.assertTrue(Signature.options)
+        self.assertTrue(Signature.subtask_type)
+
+    def test_getitem_property(self):
+        self.assertEqual(SIG.task, "TASK")
+        self.assertEqual(SIG.args, ("A1", ))
+        self.assertEqual(SIG.kwargs, {"K1": "V1"})
+        self.assertEqual(SIG.options, {"task_id": "TASK_ID"})
+        self.assertEqual(SIG.subtask_type, "")
+
+    def test_replace(self):
+        x = Signature("TASK", ("A"), {})
+        self.assertTupleEqual(x.replace(args=("B", )).args, ("B", ))
+        self.assertDictEqual(x.replace(kwargs={"FOO": "BAR"}).kwargs,
+                {"FOO": "BAR"})
+        self.assertDictEqual(x.replace(options={"task_id": "123"}).options,
+                {"task_id": "123"})
+
+    def test_set(self):
+        self.assertDictEqual(Signature("TASK", x=1).set(task_id="2").options,
+                {"x": 1, "task_id": "2"})
+
+    def test_link(self):
+        x = subtask(SIG)
+        x.link(SIG)
+        x.link(SIG)
+        self.assertIn(SIG, x.options["link"])
+        self.assertEqual(len(x.options["link"]), 1)
+
+    def test_link_error(self):
+        x = subtask(SIG)
+        x.link_error(SIG)
+        x.link_error(SIG)
+        self.assertIn(SIG, x.options["link_error"])
+        self.assertEqual(len(x.options["link_error"]), 1)
+
+    def test_flatten_links(self):
+        tasks = [add.s(2, 2), mul.s(4), div.s(2)]
+        tasks[0].link(tasks[1])
+        tasks[1].link(tasks[2])
+        self.assertEqual(tasks[0].flatten_links(), tasks)
+
+    def test_OR(self):
+        x = add.s(2, 2) | mul.s(4)
+        self.assertIsInstance(x, chain)
+        y = add.s(4, 4) | div.s(2)
+        z = x | y
+        self.assertIsInstance(y, chain)
+        self.assertIsInstance(z, chain)
+        self.assertEqual(len(z.tasks), 4)
+        with self.assertRaises(TypeError):
+            x | 10
+
+    def test_INVERT(self):
+        x = add.s(2, 2)
+        x.apply_async = Mock()
+        x.apply_async.return_value = Mock()
+        x.apply_async.return_value.get = Mock()
+        x.apply_async.return_value.get.return_value = 4
+        self.assertEqual(~x, 4)
+        self.assertTrue(x.apply_async.called)
+
+
+class test_chain(Case):
+
+    def test_repr(self):
+        x = add.s(2, 2) | add.s(2)
+        self.assertEqual(repr(x), '%s(2, 2) | %s(2)' % (add.name, add.name))
+
+    def test_reverse(self):
+        x = add.s(2, 2) | add.s(2)
+        self.assertIsInstance(subtask(x), chain)
+        self.assertIsInstance(subtask(dict(x)), chain)
+
+
+class test_group(Case):
+
+    def test_repr(self):
+        x = group([add.s(2, 2), add.s(4, 4)])
+        self.assertEqual(repr(x), repr(x.tasks))
+
+    def test_reverse(self):
+        x = group([add.s(2, 2), add.s(4, 4)])
+        self.assertIsInstance(subtask(x), group)
+        self.assertIsInstance(subtask(dict(x)), group)
+
+
+class test_chord(Case):
+
+    def test_reverse(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        self.assertIsInstance(subtask(x), chord)
+        self.assertIsInstance(subtask(dict(x)), chord)
+
+    def test_clone_clones_body(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        y = x.clone()
+        self.assertIsNot(x.kwargs["body"], y.kwargs["body"])
+        y.kwargs.pop("body")
+        z = y.clone()
+        self.assertIsNone(z.kwargs.get("body"))
+
+    def test_links_to_body(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        x.link(div.s(2))
+        self.assertFalse(x.options.get("link"))
+        self.assertTrue(x.kwargs["body"].options["link"])
+
+        x.link_error(div.s(2))
+        self.assertFalse(x.options.get("link_error"))
+        self.assertTrue(x.kwargs["body"].options["link_error"])
+
+        self.assertTrue(x.tasks)
+        self.assertTrue(x.body)
+
+    def test_repr(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        self.assertTrue(repr(x))
+        x.kwargs["body"] = None
+        self.assertIn("without body", repr(x))

+ 0 - 10
celery/tests/test_task/test_task_builtins.py

@@ -1,10 +0,0 @@
-from __future__ import absolute_import
-
-from celery.task import backend_cleanup
-from celery.tests.utils import Case
-
-
-class test_backend_cleanup(Case):
-
-    def test_run(self):
-        backend_cleanup.apply()

+ 38 - 2
celery/tests/test_task/test_task_sets.py

@@ -3,7 +3,9 @@ from __future__ import with_statement
 
 import anyjson
 
-from celery.app import app_or_default
+from mock import Mock
+
+from celery import current_app
 from celery.task import Task
 from celery.task.sets import subtask, TaskSet
 from celery.canvas import Signature
@@ -105,7 +107,7 @@ class test_TaskSet(Case):
         self.assertEqual(len(ts), 3)
 
     def test_respects_ALWAYS_EAGER(self):
-        app = app_or_default()
+        app = current_app
 
         class MockTaskSet(TaskSet):
             applied = 0
@@ -143,6 +145,25 @@ class test_TaskSet(Case):
 
         ts.apply_async(publisher=Publisher())
 
+        # setting current_task
+
+        @current_app.task
+        def xyz():
+            pass
+        from celery.app.state import _tls
+        _tls.current_task = xyz
+        try:
+            ts.apply_async(publisher=Publisher())
+        finally:
+            _tls.current_task = None
+            xyz.request.clear()
+
+        # must close publisher
+        ts._Publisher = Mock()
+        ts._Publisher.return_value = Mock()
+        ts.apply_async()
+        self.assertTrue(ts._Publisher.return_value.close.called)
+
     def test_apply(self):
 
         applied = [0]
@@ -156,3 +177,18 @@ class test_TaskSet(Case):
                         for i in (2, 4, 8)])
         ts.apply()
         self.assertEqual(applied[0], 3)
+
+    def test_set_app(self):
+        ts = TaskSet([])
+        ts.app = 42
+        self.assertEqual(ts._app, 42)
+
+    def test_set_tasks(self):
+        ts = TaskSet([])
+        ts.tasks = [1, 2, 3]
+        self.assertEqual(ts.data, [1, 2, 3])
+
+    def test_set_Publisher(self):
+        ts = TaskSet([])
+        ts.Publisher = 42
+        self.assertEqual(ts._Publisher, 42)

+ 58 - 0
celery/tests/test_utils/test_compat.py

@@ -0,0 +1,58 @@
+from __future__ import absolute_import
+
+
+import celery
+from celery.task.base import Task
+
+from celery.tests.utils import Case
+
+
+class test_MagicModule(Case):
+
+    def test_class_property_set_without_type(self):
+        self.assertTrue(Task.__dict__["app"].__get__(Task()))
+
+    def test_class_property_set_on_class(self):
+        self.assertIs(Task.__dict__["app"].__set__(None, None),
+                      Task.__dict__["app"])
+
+    def test_class_property_set(self):
+
+        class X(Task):
+            pass
+
+        app = celery.Celery(set_as_current=False)
+        Task.__dict__["app"].__set__(X(), app)
+        self.assertEqual(X.app, app)
+
+    def test_dir(self):
+        self.assertTrue(dir(celery.messaging))
+
+    def test_direct(self):
+        import sys
+        prev_celery = sys.modules.pop("celery", None)
+        prev_task = sys.modules.pop("celery.task", None)
+        try:
+            import celery
+            self.assertTrue(celery.task)
+        finally:
+            sys.modules["celery"] = prev_celery
+            sys.modules["celery.task"] = prev_task
+
+    def test_app_attrs(self):
+        self.assertEqual(celery.task.control.broadcast,
+                         celery.current_app.control.broadcast)
+
+    def test_decorators_task(self):
+        @celery.decorators.task
+        def _test_decorators_task():
+            pass
+
+        self.assertTrue(_test_decorators_task.accept_magic_kwargs)
+
+    def test_decorators_periodic_task(self):
+        @celery.decorators.periodic_task(run_every=3600)
+        def _test_decorators_ptask():
+            pass
+
+        self.assertTrue(_test_decorators_ptask.accept_magic_kwargs)

+ 276 - 0
celery/tests/test_utils/test_local.py

@@ -0,0 +1,276 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from celery.local import Proxy, PromiseProxy, maybe_evaluate, try_import
+
+from celery.tests.utils import Case
+
+
+class test_try_import(Case):
+
+    def test_imports(self):
+        self.assertTrue(try_import(__name__))
+
+    def test_when_default(self):
+        default = object()
+        self.assertIs(try_import("foobar.awqewqe.asdwqewq", default), default)
+
+
+class test_Proxy(Case):
+
+    def test_name(self):
+
+        def real():
+            """real function"""
+            return "REAL"
+
+        x = Proxy(lambda: real, name="xyz")
+        self.assertEqual(x.__name__, "xyz")
+
+        y = Proxy(lambda: real)
+        self.assertEqual(y.__name__, "real")
+
+        self.assertEqual(x.__doc__, "real function")
+
+        self.assertEqual(x.__class__, type(real))
+        self.assertEqual(x.__dict__, real.__dict__)
+        self.assertEqual(repr(x), repr(real))
+
+    def test_nonzero(self):
+
+        class X(object):
+
+            def __nonzero__(self):
+                return False
+
+        x = Proxy(lambda: X())
+        self.assertFalse(x)
+
+    def test_slots(self):
+
+        class X(object):
+            __slots__ = ()
+
+        x = Proxy(X)
+        with self.assertRaises(AttributeError):
+            x.__dict__
+
+    def test_unicode(self):
+
+        class X(object):
+
+            def __unicode__(self):
+                return u"UNICODE"
+
+            def __repr__(self):
+                return "REPR"
+
+        x = Proxy(lambda: X())
+        self.assertEqual(unicode(x), u"UNICODE")
+        del(X.__unicode__)
+        self.assertEqual(unicode(x), "REPR")
+
+    def test_dir(self):
+
+        class X(object):
+
+            def __dir__(self):
+                return ["a", "b", "c"]
+
+        x = Proxy(lambda: X())
+        self.assertListEqual(dir(x), ["a", "b", "c"])
+
+        class Y(object):
+
+            def __dir__(self):
+                raise RuntimeError()
+        y = Proxy(lambda: Y())
+        self.assertListEqual(dir(y), [])
+
+    def test_getsetdel_attr(self):
+
+        class X(object):
+            a = 1
+            b = 2
+            c = 3
+
+            def __dir__(self):
+                return ["a", "b", "c"]
+
+        v = X()
+
+        x = Proxy(lambda: v)
+        self.assertListEqual(x.__members__, ["a", "b", "c"])
+        self.assertEqual(x.a, 1)
+        self.assertEqual(x.b, 2)
+        self.assertEqual(x.c, 3)
+
+        setattr(x, "a", 10)
+        self.assertEqual(x.a, 10)
+
+        del(x.a)
+        self.assertEqual(x.a, 1)
+
+    def test_dictproxy(self):
+        v = {}
+        x = Proxy(lambda: v)
+        x["foo"] = 42
+        self.assertEqual(x["foo"], 42)
+        self.assertEqual(len(x), 1)
+        self.assertIn("foo", x)
+        del(x["foo"])
+        with self.assertRaises(KeyError):
+            x["foo"]
+        self.assertTrue(iter(x))
+
+    def test_listproxy(self):
+        v = []
+        x = Proxy(lambda: v)
+        x.append(1)
+        x.extend([2, 3, 4])
+        self.assertEqual(x[0], 1)
+        self.assertEqual(x[:-1], [1, 2, 3])
+        del(x[-1])
+        self.assertEqual(x[:-1], [1, 2])
+        x[0] = 10
+        self.assertEqual(x[0], 10)
+        self.assertIn(10, x)
+        self.assertEqual(len(x), 3)
+        self.assertTrue(iter(x))
+
+    def test_int(self):
+        self.assertEqual(Proxy(lambda: 10) + 1, Proxy(lambda: 11))
+        self.assertEqual(Proxy(lambda: 10) - 1, Proxy(lambda: 9))
+        self.assertEqual(Proxy(lambda: 10) * 2, Proxy(lambda: 20))
+        self.assertEqual(Proxy(lambda: 10) ** 2, Proxy(lambda: 100))
+        self.assertEqual(Proxy(lambda: 20) / 2, Proxy(lambda: 10))
+        self.assertEqual(Proxy(lambda: 20) // 2, Proxy(lambda: 10))
+        self.assertEqual(Proxy(lambda: 11) % 2, Proxy(lambda: 1))
+        self.assertEqual(Proxy(lambda: 10) << 2, Proxy(lambda: 40))
+        self.assertEqual(Proxy(lambda: 10) >> 2, Proxy(lambda: 2))
+        self.assertEqual(Proxy(lambda: 10) ^ 7, Proxy(lambda: 13))
+        self.assertEqual(Proxy(lambda: 10) | 40, Proxy(lambda: 42))
+        self.assertEqual(~Proxy(lambda: 10), Proxy(lambda: -11))
+        self.assertEqual(-Proxy(lambda: 10), Proxy(lambda: -10))
+        self.assertEqual(+Proxy(lambda: -10), Proxy(lambda: -10))
+        self.assertTrue(Proxy(lambda: 10) < Proxy(lambda: 20))
+        self.assertTrue(Proxy(lambda: 20) > Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 10) >= Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 10) <= Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 10) == Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 20) != Proxy(lambda: 10))
+
+        x = Proxy(lambda: 10)
+        x -= 1
+        self.assertEqual(x, 9)
+        x = Proxy(lambda: 9)
+        x += 1
+        self.assertEqual(x, 10)
+        x = Proxy(lambda: 10)
+        x *= 2
+        self.assertEqual(x, 20)
+        x = Proxy(lambda: 20)
+        x /= 2
+        self.assertEqual(x, 10)
+        x = Proxy(lambda: 10)
+        x %= 2
+        self.assertEqual(x, 0)
+        x = Proxy(lambda: 10)
+        x <<= 3
+        self.assertEqual(x, 80)
+        x = Proxy(lambda: 80)
+        x >>= 4
+        self.assertEqual(x, 5)
+        x = Proxy(lambda: 5)
+        x ^= 1
+        self.assertEqual(x, 4)
+        x = Proxy(lambda: 4)
+        x **= 4
+        self.assertEqual(x, 256)
+        x = Proxy(lambda: 256)
+        x //= 2
+        self.assertEqual(x, 128)
+        x = Proxy(lambda: 128)
+        x |= 2
+        self.assertEqual(x, 130)
+        x = Proxy(lambda: 130)
+        x &= 10
+        self.assertEqual(x, 2)
+
+        x = Proxy(lambda: 10)
+        self.assertEqual(type(x.__float__()), float)
+        self.assertEqual(type(x.__int__()), int)
+        self.assertEqual(type(x.__long__()), long)
+        self.assertTrue(hex(x))
+        self.assertTrue(oct(x))
+
+    def test_hash(self):
+
+        class X(object):
+
+            def __hash__(self):
+                return 1234
+
+        self.assertEqual(hash(Proxy(lambda: X())), 1234)
+
+    def test_call(self):
+
+        class X(object):
+
+            def __call__(self):
+                return 1234
+
+        self.assertEqual(Proxy(lambda: X())(), 1234)
+
+    def test_context(self):
+
+        class X(object):
+            entered = exited = False
+
+            def __enter__(self):
+                self.entered = True
+                return 1234
+
+            def __exit__(self, *exc_info):
+                self.exited = True
+
+        v = X()
+        x = Proxy(lambda: v)
+        with x as val:
+            self.assertEqual(val, 1234)
+        self.assertTrue(x.entered)
+        self.assertTrue(x.exited)
+
+    def test_reduce(self):
+
+        class X(object):
+
+            def __reduce__(self):
+                return 123
+
+        x = Proxy(lambda: X())
+        self.assertEqual(x.__reduce__(), 123)
+
+
+class test_PromiseProxy(Case):
+
+    def test_only_evaluated_once(self):
+
+        class X(object):
+            attr = 123
+            evals = 0
+
+            def __init__(self):
+                self.__class__.evals += 1
+
+        p = PromiseProxy(X)
+        self.assertEqual(p.attr, 123)
+        self.assertEqual(p.attr, 123)
+        self.assertEqual(X.evals, 1)
+
+    def test_maybe_evaluate(self):
+        x = PromiseProxy(lambda: 30)
+        self.assertEqual(maybe_evaluate(x), 30)
+        self.assertEqual(maybe_evaluate(x), 30)
+
+        self.assertEqual(maybe_evaluate(30), 30)

+ 0 - 4
celery/tests/test_utils/test_timer2.py

@@ -48,14 +48,10 @@ class test_Schedule(Case):
 
         timer2.mktime = _overflow
         try:
-            print("+S1")
             s.enter(timer2.Entry(lambda: None, (), {}),
                     eta=datetime.now())
-            print("-S1")
-            print("+S2")
             s.enter(timer2.Entry(lambda: None, (), {}),
                     eta=None)
-            print("-S2")
             s.on_error = None
             with self.assertRaises(OverflowError):
                 s.enter(timer2.Entry(lambda: None, (), {}),

+ 36 - 1
celery/tests/test_utils/test_utils_imports.py

@@ -1,6 +1,13 @@
 from __future__ import absolute_import
 
-from celery.utils.imports import qualname, symbol_by_name
+from mock import Mock, patch
+
+from celery.utils.imports import (
+    qualname,
+    symbol_by_name,
+    reload_from_cwd,
+    module_file,
+)
 
 from celery.tests.utils import Case
 
@@ -15,3 +22,31 @@ class test_import_utils(Case):
     def test_symbol_by_name__instance_returns_instance(self):
         instance = object()
         self.assertIs(symbol_by_name(instance), instance)
+
+    def test_symbol_by_name_returns_default(self):
+        default = object()
+        self.assertIs(symbol_by_name("xyz.ryx.qedoa.weq:foz",
+                        default=default), default)
+
+    def test_symbol_by_name_package(self):
+        from celery.worker import WorkController
+        self.assertIs(symbol_by_name(".worker:WorkController",
+                    package="celery"), WorkController)
+
+    @patch("celery.utils.imports.reload")
+    def test_reload_from_cwd(self, reload):
+        reload_from_cwd("foo")
+        self.assertTrue(reload.called)
+
+    def test_reload_from_cwd_custom_reloader(self):
+        reload = Mock()
+        reload_from_cwd("foo", reload)
+        self.assertTrue(reload.called)
+
+    def test_module_file(self):
+        m1 = Mock()
+        m1.__file__ = "/opt/foo/xyz.pyc"
+        self.assertEqual(module_file(m1), "/opt/foo/xyz.py")
+        m2 = Mock()
+        m2.__file__ = "/opt/foo/xyz.py"
+        self.assertEqual(module_file(m1), "/opt/foo/xyz.py")

+ 0 - 1
celery/tests/utils.py

@@ -434,7 +434,6 @@ platform_pyimp = partial(
 )
 
 
-
 @contextmanager
 def sys_platform(value):
     prev, sys.platform = sys.platform, value

+ 1 - 3
celery/utils/__init__.py

@@ -22,12 +22,9 @@ from functools import partial, wraps
 from inspect import getargspec
 from pprint import pprint
 
-from billiard.util import register_after_fork
-
 from celery.exceptions import CPendingDeprecationWarning, CDeprecationWarning
 from .compat import StringIO
 
-from .imports import symbol_by_name, qualname
 from .functional import noop
 
 PENDING_DEPRECATION_FMT = """
@@ -61,6 +58,7 @@ def deprecated(description=None, deprecation=None, removal=None,
 
         @wraps(fun)
         def __inner(*args, **kwargs):
+            from .imports import qualname
             warn_deprecated(description=description or qualname(fun),
                             deprecation=deprecation,
                             removal=removal,

+ 2 - 2
celery/utils/imports.py

@@ -15,7 +15,7 @@ class NotAPackage(Exception):
     pass
 
 
-if sys.version_info >= (3, 3):
+if sys.version_info >= (3, 3):  # pragma: no cover
 
     def qualname(obj):
         return obj.__qualname__
@@ -109,7 +109,7 @@ def cwd_in_path():
         finally:
             try:
                 sys.path.remove(cwd)
-            except ValueError:
+            except ValueError:  # pragma: no cover
                 pass
 
 

+ 18 - 25
celery/utils/log.py

@@ -21,14 +21,14 @@ is_py3k = sys.version_info[0] == 3
 # Every logger in the celery package inherits from the "celery"
 # logger, and every task logger inherits from the "celery.task"
 # logger.
-logger = _get_logger("celery")
+base_logger = logger = _get_logger("celery")
 mp_logger = _get_logger("multiprocessing")
 
 
 def get_logger(name):
     l = _get_logger(name)
-    if l.parent is logging.root and l is not logger:
-        l.parent = logger
+    if logging.root not in (l, l.parent) and l is not base_logger:
+        l.parent = base_logger
     return l
 task_logger = get_logger("celery.task")
 
@@ -163,34 +163,27 @@ class LoggingProxy(object):
         return None
 
 
-def _patch_logger_class():
+def ensure_process_aware_logger():
     """Make sure process name is recorded when loggers are used."""
-    logging._acquireLock()
-    try:
-        OldLoggerClass = logging.getLoggerClass()
-        if not getattr(OldLoggerClass, '_process_aware', False):
-
-            class ProcessAwareLogger(OldLoggerClass):
+    global _process_aware
+    if not _process_aware:
+        logging._acquireLock()
+        try:
+            _process_aware = True
+            Logger = logging.getLoggerClass()
+            if getattr(Logger, '_process_aware', False):  # pragma: no cover
+                return
+
+            class ProcessAwareLogger(Logger):
                 _process_aware = True
 
                 def makeRecord(self, *args, **kwds):
-                    record = OldLoggerClass.makeRecord(self, *args, **kwds)
-                    if current_process:
-                        record.processName = current_process()._name
-                    else:
-                        record.processName = ""
+                    record = Logger.makeRecord(self, *args, **kwds)
+                    record.processName = current_process()._name
                     return record
             logging.setLoggerClass(ProcessAwareLogger)
-    finally:
-        logging._releaseLock()
-
-
-def ensure_process_aware_logger():
-    global _process_aware
-
-    if not _process_aware:
-        _patch_logger_class()
-        _process_aware = True
+        finally:
+            logging._releaseLock()
 
 
 def get_multiprocessing_logger():

+ 0 - 6
setup.cfg

@@ -9,17 +9,11 @@ cover3-exclude = celery
                  celery.bin.celeryd_detach
                  celery.bin.celeryctl
                  celery.bin.camqadm
-                 celery.execute
-                 celery.local
                  celery.platforms
-                 celery.utils.encoding
-                 celery.utils.patch
                  celery.utils.compat
                  celery.utils.mail
-                 celery.utils.functional
                  celery.utils.dispatch*
                  celery.utils.term
-                 celery.messaging
                  celery.db.a805d4bd
                  celery.db.dfd042c7
                  celery.contrib*