Ask Solem před 13 roky
rodič
revize
f51114dd9a

+ 15 - 3
celery/__compat__.py

@@ -15,7 +15,7 @@ DEFAULT_ATTRS = set(["__file__", "__path__", "__doc__", "__all__"])
 
 
 # im_func is no longer available in Py3.
 # im_func is no longer available in Py3.
 # instead the unbound method itself can be used.
 # instead the unbound method itself can be used.
-if sys.version_info[0] == 3:
+if sys.version_info[0] == 3:  # pragma: no cover
     def fun_of_method(method):
     def fun_of_method(method):
         return method
         return method
 else:
 else:
@@ -69,6 +69,17 @@ COMPAT_MODULES = {
             "tasks": "tasks",
             "tasks": "tasks",
         },
         },
     },
     },
+    "celery.task": {
+        "control": {
+            "broadcast": "control.broadcast",
+            "rate_limit": "control.rate_limit",
+            "time_limit": "control.time_limit",
+            "ping": "control.ping",
+            "revoke": "control.revoke",
+            "discard_all": "control.discard_all",
+            "inspect": "control.inspect",
+        }
+    }
 }
 }
 
 
 
 
@@ -158,8 +169,9 @@ def get_compat_module(pkg, name):
             return Proxy(getappattr, (attr.split('.'), ))
             return Proxy(getappattr, (attr.split('.'), ))
         return attr
         return attr
 
 
-    return create_module(name, COMPAT_MODULES[pkg.__name__][name],
-                         pkg=pkg, prepare_attr=prepare)
+    attrs = dict(COMPAT_MODULES[pkg.__name__][name])
+    attrs["__all__"] = attrs.keys()
+    return create_module(name, attrs, pkg=pkg, prepare_attr=prepare)
 
 
 
 
 def get_origins(defs):
 def get_origins(defs):

+ 0 - 1
celery/__init__.py

@@ -17,7 +17,6 @@ __docformat__ = "restructuredtext"
 # Lazy loading
 # Lazy loading
 from .__compat__ import recreate_module
 from .__compat__ import recreate_module
 
 
-
 old_module, new_module = recreate_module(__name__,
 old_module, new_module = recreate_module(__name__,
     by_module={
     by_module={
         "celery.app":       ["Celery", "bugreport"],
         "celery.app":       ["Celery", "bugreport"],

+ 1 - 1
celery/app/__init__.py

@@ -43,7 +43,7 @@ set_default_app(Celery("default", loader=default_loader,
 
 
 
 
 def bugreport():
 def bugreport():
-    return current_app.bugreport()
+    return current_app().bugreport()
 
 
 
 
 def _app_or_default(app=None):
 def _app_or_default(app=None):

+ 4 - 2
celery/app/amqp.py

@@ -16,7 +16,7 @@ from datetime import timedelta
 from kombu import BrokerConnection, Exchange
 from kombu import BrokerConnection, Exchange
 from kombu import compat as messaging
 from kombu import compat as messaging
 from kombu import pools
 from kombu import pools
-from kombu.common import maybe_declare
+from kombu.common import declaration_cached, maybe_declare
 
 
 from celery import signals
 from celery import signals
 from celery.utils import cached_property, lpmerge, uuid
 from celery.utils import cached_property, lpmerge, uuid
@@ -35,6 +35,7 @@ QUEUE_FORMAT = """
 binding:%(binding_key)s
 binding:%(binding_key)s
 """
 """
 
 
+
 def extract_msg_options(options, keep=MSG_OPTIONS):
 def extract_msg_options(options, keep=MSG_OPTIONS):
     """Extracts known options to `basic_publish` from a dict,
     """Extracts known options to `basic_publish` from a dict,
     and returns a new dict."""
     and returns a new dict."""
@@ -157,7 +158,8 @@ class TaskPublisher(messaging.Publisher):
         super(TaskPublisher, self).__init__(*args, **kwargs)
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
 
     def declare(self):
     def declare(self):
-        if self.exchange.name and not declaration_cached(self.exchange):
+        if self.exchange.name and \
+                not declaration_cached(self.exchange, self.channel):
             super(TaskPublisher, self).declare()
             super(TaskPublisher, self).declare()
 
 
     def _get_queue(self, name):
     def _get_queue(self, name):

+ 8 - 6
celery/app/base.py

@@ -19,13 +19,14 @@ from contextlib import contextmanager
 from copy import deepcopy
 from copy import deepcopy
 from functools import wraps
 from functools import wraps
 
 
+from billiard.util import register_after_fork
 from kombu.clocks import LamportClock
 from kombu.clocks import LamportClock
+from kombu.utils import cached_property
 
 
 from celery import platforms
 from celery import platforms
 from celery.exceptions import AlwaysEagerIgnored
 from celery.exceptions import AlwaysEagerIgnored
 from celery.loaders import get_loader_cls
 from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
 from celery.local import PromiseProxy, maybe_evaluate
-from celery.utils import cached_property, register_after_fork
 from celery.utils.functional import first
 from celery.utils.functional import first
 from celery.utils.imports import instantiate, symbol_by_name
 from celery.utils.imports import instantiate, symbol_by_name
 
 
@@ -159,7 +160,7 @@ class Celery(object):
             eta=None, task_id=None, publisher=None, connection=None,
             eta=None, task_id=None, publisher=None, connection=None,
             connect_timeout=None, result_cls=None, expires=None,
             connect_timeout=None, result_cls=None, expires=None,
             queues=None, **options):
             queues=None, **options):
-        if self.conf.CELERY_ALWAYS_EAGER:
+        if self.conf.CELERY_ALWAYS_EAGER:  # pragma: no cover
             warnings.warn(AlwaysEagerIgnored(
             warnings.warn(AlwaysEagerIgnored(
                 "CELERY_ALWAYS_EAGER has no effect on send_task"))
                 "CELERY_ALWAYS_EAGER has no effect on send_task"))
 
 
@@ -234,9 +235,6 @@ class Celery(object):
 
 
     def prepare_config(self, c):
     def prepare_config(self, c):
         """Prepare configuration before it is merged with the defaults."""
         """Prepare configuration before it is merged with the defaults."""
-        if self._preconf:
-            for key, value in self._preconf.iteritems():
-                setattr(c, key, value)
         return find_deprecated_settings(c)
         return find_deprecated_settings(c)
 
 
     def now(self):
     def now(self):
@@ -275,8 +273,12 @@ class Celery(object):
         return backend(app=self, url=url)
         return backend(app=self, url=url)
 
 
     def _get_config(self):
     def _get_config(self):
-        return Settings({}, [self.prepare_config(self.loader.conf),
+        s = Settings({}, [self.prepare_config(self.loader.conf),
                              deepcopy(DEFAULTS)])
                              deepcopy(DEFAULTS)])
+        if self._preconf:
+            for key, value in self._preconf.iteritems():
+                setattr(s, key, value)
+        return s
 
 
     def _after_fork(self, obj_):
     def _after_fork(self, obj_):
         if self._pool:
         if self._pool:

+ 18 - 11
celery/app/builtins.py

@@ -39,7 +39,11 @@ def add_backend_cleanup_task(app):
     may even clean up in realtime so that a periodic cleanup is not necessary.
     may even clean up in realtime so that a periodic cleanup is not necessary.
 
 
     """
     """
-    return app.task(name="celery.backend_cleanup")(app.backend.cleanup)
+
+    @app.task(name="celery.backend_cleanup")
+    def backend_cleanup():
+        app.backend.cleanup()
+    return backend_cleanup
 
 
 
 
 @builtin_task
 @builtin_task
@@ -79,17 +83,19 @@ def add_group_task(app):
         def run(self, tasks, result):
         def run(self, tasks, result):
             app = self.app
             app = self.app
             result = from_serializable(result)
             result = from_serializable(result)
+            if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
+                return app.TaskSetResult(result.id,
+                        [subtask(task).apply(taskset_id=self.request.taskset)
+                            for task in tasks])
             with app.pool.acquire(block=True) as conn:
             with app.pool.acquire(block=True) as conn:
                 with app.amqp.TaskPublisher(conn) as publisher:
                 with app.amqp.TaskPublisher(conn) as publisher:
-                    res_ = [subtask(task).apply_async(
-                                        taskset_id=self.request.taskset,
-                                        publisher=publisher)
-                                for task in tasks]
+                    [subtask(task).apply_async(
+                                    taskset_id=self.request.taskset,
+                                    publisher=publisher)
+                            for task in tasks]
             parent = get_current_task()
             parent = get_current_task()
             if parent:
             if parent:
                 parent.request.children.append(result)
                 parent.request.children.append(result)
-            if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
-                return app.TaskSetResult(result.id, res_)
             return result
             return result
 
 
         def prepare(self, options, tasks, **kwargs):
         def prepare(self, options, tasks, **kwargs):
@@ -111,8 +117,7 @@ def add_group_task(app):
 
 
         def apply(self, args=(), kwargs={}, **options):
         def apply(self, args=(), kwargs={}, **options):
             tasks, result = self.prepare(options, **kwargs)
             tasks, result = self.prepare(options, **kwargs)
-            return super(Group, self).apply((tasks, result), {"eager": True},
-                                            **options)
+            return super(Group, self).apply((tasks, result), **options)
 
 
     return Group
     return Group
 
 
@@ -177,8 +182,10 @@ def add_chord_task(app):
             body = maybe_subtask(kwargs["body"])
             body = maybe_subtask(kwargs["body"])
 
 
             callback_id = body.options.setdefault("task_id", task_id or uuid())
             callback_id = body.options.setdefault("task_id", task_id or uuid())
-            super(Chord, self).apply_async(args, kwargs, **options)
-            return self.AsyncResult(callback_id)
+            parent = super(Chord, self).apply_async(args, kwargs, **options)
+            body_result = self.AsyncResult(callback_id)
+            body_result.parent = parent
+            return body_result
 
 
         def apply(self, args=(), kwargs={}, **options):
         def apply(self, args=(), kwargs={}, **options):
             body = kwargs["body"]
             body = kwargs["body"]

+ 1 - 0
celery/apps/worker.py

@@ -279,6 +279,7 @@ install_worker_term_hard_handler = partial(
     _shutdown_handler, sig="SIGQUIT", how="terminate", exc=SystemTerminate,
     _shutdown_handler, sig="SIGQUIT", how="terminate", exc=SystemTerminate,
 )
 )
 
 
+
 def on_SIGINT(worker):
 def on_SIGINT(worker):
     print("celeryd: Hitting Ctrl+C again will terminate all running tasks!")
     print("celeryd: Hitting Ctrl+C again will terminate all running tasks!")
     install_worker_term_hard_handler(worker, sig="SIGINT")
     install_worker_term_hard_handler(worker, sig="SIGINT")

+ 2 - 1
celery/beat.py

@@ -448,7 +448,8 @@ class _Threaded(threading.Thread):
 
 
 supports_fork = True
 supports_fork = True
 try:
 try:
-    import _multiprocessing
+    from billiard._ext import _billiard
+    supports_fork = True if _billiard else False
 except ImportError:
 except ImportError:
     supports_fork = False
     supports_fork = False
 
 

+ 10 - 7
celery/canvas.py

@@ -162,10 +162,12 @@ class Signature(dict):
 
 
     def __or__(self, other):
     def __or__(self, other):
         if isinstance(other, chain):
         if isinstance(other, chain):
-            return chain(self.tasks + other.tasks)
+            return chain(*self.tasks + other.tasks)
         elif isinstance(other, Signature):
         elif isinstance(other, Signature):
+            if isinstance(self, chain):
+                return chain(*self.tasks + (other, ))
             return chain(self, other)
             return chain(self, other)
-        return NotImplementedError
+        return NotImplemented
 
 
     def __invert__(self):
     def __invert__(self):
         return self.apply_async().get()
         return self.apply_async().get()
@@ -241,7 +243,8 @@ class chord(Signature):
     @classmethod
     @classmethod
     def from_dict(self, d):
     def from_dict(self, d):
         kwargs = d["kwargs"]
         kwargs = d["kwargs"]
-        return chord(kwargs["header"], kwargs["body"], **kwdict(d["options"]))
+        return chord(kwargs["header"], kwargs.get("body"),
+                     **kwdict(d["options"]))
 
 
     def __call__(self, body=None, **options):
     def __call__(self, body=None, **options):
         _chord = self.Chord
         _chord = self.Chord
@@ -254,10 +257,10 @@ class chord(Signature):
 
 
     def clone(self, *args, **kwargs):
     def clone(self, *args, **kwargs):
         s = Signature.clone(self, *args, **kwargs)
         s = Signature.clone(self, *args, **kwargs)
-        # need make copy of body
+        # need to make copy of body
         try:
         try:
-            kwargs["body"] = kwargs["body"].clone()
-        except KeyError:
+            s.kwargs["body"] = s.kwargs["body"].clone()
+        except (AttributeError, KeyError):
             pass
             pass
         return s
         return s
 
 
@@ -280,7 +283,7 @@ class chord(Signature):
 
 
     @property
     @property
     def body(self):
     def body(self):
-        return self.kwargs["body"]
+        return self.kwargs.get("body")
 Signature.register_type(chord)
 Signature.register_type(chord)
 
 
 
 

+ 0 - 2
celery/concurrency/eventlet.py

@@ -8,8 +8,6 @@ if not os.environ.get("EVENTLET_NOPATCH"):
     eventlet.monkey_patch()
     eventlet.monkey_patch()
     eventlet.debug.hub_prevent_multiple_readers(False)
     eventlet.debug.hub_prevent_multiple_readers(False)
 
 
-import sys
-
 from time import time
 from time import time
 
 
 from celery import signals
 from celery import signals

+ 0 - 2
celery/concurrency/gevent.py

@@ -6,8 +6,6 @@ if not os.environ.get("GEVENT_NOPATCH"):
     from gevent import monkey
     from gevent import monkey
     monkey.patch_all()
     monkey.patch_all()
 
 
-import sys
-
 from time import time
 from time import time
 
 
 from celery.utils import timer2
 from celery.utils import timer2

+ 1 - 0
celery/loaders/__init__.py

@@ -19,6 +19,7 @@ LOADER_ALIASES = {"app": "celery.loaders.app:AppLoader",
                   "default": "celery.loaders.default:Loader",
                   "default": "celery.loaders.default:Loader",
                   "django": "djcelery.loaders:DjangoLoader"}
                   "django": "djcelery.loaders:DjangoLoader"}
 
 
+
 def get_loader_cls(loader):
 def get_loader_cls(loader):
     """Get loader class by name/alias"""
     """Get loader class by name/alias"""
     return symbol_by_name(loader, LOADER_ALIASES)
     return symbol_by_name(loader, LOADER_ALIASES)

+ 10 - 9
celery/local.py

@@ -34,12 +34,13 @@ class Proxy(object):
         object.__setattr__(self, '_Proxy__local', local)
         object.__setattr__(self, '_Proxy__local', local)
         object.__setattr__(self, '_Proxy__args', args or ())
         object.__setattr__(self, '_Proxy__args', args or ())
         object.__setattr__(self, '_Proxy__kwargs', kwargs or {})
         object.__setattr__(self, '_Proxy__kwargs', kwargs or {})
-        object.__setattr__(self, '__custom_name__', name)
+        if name is not None:
+            object.__setattr__(self, '__custom_name__', name)
 
 
     @property
     @property
     def __name__(self):
     def __name__(self):
         try:
         try:
-            return object.__getattr__(self, "__custom_name__")
+            return self.__custom_name__
         except AttributeError:
         except AttributeError:
             return self._get_current_object().__name__
             return self._get_current_object().__name__
 
 
@@ -67,32 +68,32 @@ class Proxy(object):
     def __dict__(self):
     def __dict__(self):
         try:
         try:
             return self._get_current_object().__dict__
             return self._get_current_object().__dict__
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             raise AttributeError('__dict__')
             raise AttributeError('__dict__')
 
 
     def __repr__(self):
     def __repr__(self):
         try:
         try:
             obj = self._get_current_object()
             obj = self._get_current_object()
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return '<%s unbound>' % self.__class__.__name__
             return '<%s unbound>' % self.__class__.__name__
         return repr(obj)
         return repr(obj)
 
 
     def __nonzero__(self):
     def __nonzero__(self):
         try:
         try:
             return bool(self._get_current_object())
             return bool(self._get_current_object())
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return False
             return False
 
 
     def __unicode__(self):
     def __unicode__(self):
         try:
         try:
             return unicode(self._get_current_object())
             return unicode(self._get_current_object())
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return repr(self)
             return repr(self)
 
 
     def __dir__(self):
     def __dir__(self):
         try:
         try:
             return dir(self._get_current_object())
             return dir(self._get_current_object())
-        except RuntimeError:
+        except RuntimeError:  # pragma: no cover
             return []
             return []
 
 
     def __getattr__(self, name):
     def __getattr__(self, name):
@@ -155,8 +156,8 @@ class Proxy(object):
     __hex__ = lambda x: hex(x._get_current_object())
     __hex__ = lambda x: hex(x._get_current_object())
     __index__ = lambda x: x._get_current_object().__index__()
     __index__ = lambda x: x._get_current_object().__index__()
     __coerce__ = lambda x, o: x.__coerce__(x, o)
     __coerce__ = lambda x, o: x.__coerce__(x, o)
-    __enter__ = lambda x: x.__enter__()
-    __exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw)
+    __enter__ = lambda x: x._get_current_object().__enter__()
+    __exit__ = lambda x, *a, **kw: x._get_current_object().__exit__(*a, **kw)
     __reduce__ = lambda x: x._get_current_object().__reduce__()
     __reduce__ = lambda x: x._get_current_object().__reduce__()
 
 
 
 

+ 1 - 0
celery/security/serialization.py

@@ -1,4 +1,5 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
 
 
 import base64
 import base64
 
 

+ 2 - 2
celery/security/utils.py

@@ -8,8 +8,8 @@ from celery.exceptions import SecurityError
 
 
 try:
 try:
     from OpenSSL import crypto
     from OpenSSL import crypto
-except ImportError:
-    crypto = None  # noqa
+except ImportError:  # pragma: no cover
+    crypto = None    # noqa
 
 
 
 
 @contextmanager
 @contextmanager

+ 0 - 13
celery/task/control.py

@@ -1,13 +0,0 @@
-from __future__ import absolute_import
-
-from celery import current_app
-from celery.local import Proxy
-
-
-broadcast = Proxy(lambda: current_app.control.broadcast)
-rate_limit = Proxy(lambda: current_app.control.rate_limit)
-time_limit = Proxy(lambda: current_app.control.time_limit)
-ping = Proxy(lambda: current_app.control.ping)
-revoke = Proxy(lambda: current_app.control.revoke)
-discard_all = Proxy(lambda: current_app.control.discard_all)
-inspect = Proxy(lambda: current_app.control.inspect)

+ 32 - 2
celery/tests/test_app/__init__.py

@@ -3,11 +3,13 @@ from __future__ import with_statement
 
 
 import os
 import os
 
 
-from mock import Mock
+from mock import Mock, patch
+from pickle import loads, dumps
 
 
 from celery import Celery
 from celery import Celery
 from celery import app as _app
 from celery import app as _app
 from celery.app import defaults
 from celery.app import defaults
+from celery.app import state
 from celery.loaders.base import BaseLoader
 from celery.loaders.base import BaseLoader
 from celery.platforms import pyimplementation
 from celery.platforms import pyimplementation
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
@@ -36,6 +38,15 @@ def _get_test_config():
 test_config = _get_test_config()
 test_config = _get_test_config()
 
 
 
 
+class test_module(Case):
+
+    def test_default_app(self):
+        self.assertEqual(_app.default_app, state.default_app)
+
+    def test_bugreport(self):
+        self.assertTrue(_app.bugreport())
+
+
 class test_App(Case):
 class test_App(Case):
 
 
     def setUp(self):
     def setUp(self):
@@ -52,6 +63,10 @@ class test_App(Case):
         task = app.task(fun)
         task = app.task(fun)
         self.assertEqual(task.name, app.main + ".fun")
         self.assertEqual(task.name, app.main + ".fun")
 
 
+    def test_with_broker(self):
+        app = Celery(set_as_current=False, broker="foo://baribaz")
+        self.assertEqual(app.conf.BROKER_HOST, "foo://baribaz")
+
     def test_repr(self):
     def test_repr(self):
         self.assertTrue(repr(self.app))
         self.assertTrue(repr(self.app))
 
 
@@ -148,6 +163,22 @@ class test_App(Case):
         self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
         self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
         self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
         self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
 
 
+    def test_WorkController(self):
+        x = self.app.Worker()
+        self.assertIs(x.app, self.app)
+
+    def test_AsyncResult(self):
+        x = self.app.AsyncResult("1")
+        self.assertIs(x.app, self.app)
+        r = loads(dumps(x))
+        # not set as current, so ends up as default app after reduce
+        self.assertIs(r.app, state.default_app)
+
+    @patch("celery.bin.celery.CeleryCommand.execute_from_commandline")
+    def test_start(self, execute):
+        self.app.start()
+        self.assertTrue(execute.called)
+
     def test_mail_admins(self):
     def test_mail_admins(self):
 
 
         class Loader(BaseLoader):
         class Loader(BaseLoader):
@@ -208,7 +239,6 @@ class test_App(Case):
         self.assertTrue(self.app.bugreport())
         self.assertTrue(self.app.bugreport())
 
 
     def test_send_task_sent_event(self):
     def test_send_task_sent_event(self):
-        from celery.app import amqp
 
 
         class Dispatcher(object):
         class Dispatcher(object):
             sent = []
             sent = []

+ 1 - 1
celery/tests/test_app/test_beat.py

@@ -306,7 +306,7 @@ class test_EmbeddedService(Case):
 
 
     def test_start_stop_process(self):
     def test_start_stop_process(self):
         try:
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
         except ImportError:
             raise SkipTest("multiprocessing not available")
             raise SkipTest("multiprocessing not available")
 
 

+ 113 - 0
celery/tests/test_app/test_builtins.py

@@ -0,0 +1,113 @@
+from __future__ import absolute_import
+
+from mock import Mock
+
+from celery import current_app as app, group, task, chord
+from celery.app import builtins
+from celery.app.state import _tls
+from celery.tests.utils import Case
+
+
+@task
+def add(x, y):
+    return x + y
+
+
+@task
+def xsum(x):
+    return sum(x)
+
+
+class test_backend_cleanup(Case):
+
+    def test_run(self):
+        prev = app.backend
+        app.backend.cleanup = Mock()
+        app.backend.cleanup.__name__ = "cleanup"
+        try:
+            cleanup_task = builtins.add_backend_cleanup_task(app)
+            cleanup_task()
+            self.assertTrue(app.backend.cleanup.called)
+        finally:
+            app.backend = prev
+
+
+class test_group(Case):
+
+    def setUp(self):
+        self.prev = app.tasks.get("celery.group")
+        self.task = builtins.add_group_task(app)()
+
+    def tearDown(self):
+        app.tasks["celery.group"] = self.prev
+
+    def test_apply_async_eager(self):
+        self.task.apply = Mock()
+        app.conf.CELERY_ALWAYS_EAGER = True
+        try:
+            self.task.apply_async()
+        finally:
+            app.conf.CELERY_ALWAYS_EAGER = False
+        self.assertTrue(self.task.apply.called)
+
+    def test_apply(self):
+        x = group([add.s(4, 4), add.s(8, 8)])
+        x.name = self.task.name
+        res = x.apply()
+        self.assertEqual(res.get().join(), [8, 16])
+
+    def test_apply_async(self):
+        x = group([add.s(4, 4), add.s(8, 8)])
+        x.apply_async()
+
+    def test_apply_async_with_parent(self):
+        _tls.current_task = add
+        try:
+            x = group([add.s(4, 4), add.s(8, 8)])
+            x.apply_async()
+            self.assertTrue(add.request.children)
+        finally:
+            _tls.current_task = None
+
+
+class test_chain(Case):
+
+    def setUp(self):
+        self.prev = app.tasks.get("celery.chain")
+        self.task = builtins.add_chain_task(app)()
+
+    def tearDown(self):
+        app.tasks["celery.chain"] = self.prev
+
+    def test_apply_async(self):
+        c = add.s(2, 2) | add.s(4) | add.s(8)
+        result = c.apply_async()
+        self.assertTrue(result.parent)
+        self.assertTrue(result.parent.parent)
+        self.assertIsNone(result.parent.parent.parent)
+
+
+class test_chord(Case):
+
+    def setUp(self):
+        self.prev = app.tasks.get("celery.chord")
+        self.task = builtins.add_chain_task(app)()
+
+    def tearDown(self):
+        app.tasks["celery.chord"] = self.prev
+
+    def test_apply_async(self):
+        x = chord([add.s(i, i) for i in xrange(10)], body=xsum.s())
+        r = x.apply_async()
+        self.assertTrue(r)
+        self.assertTrue(r.parent)
+
+    def test_apply_eager(self):
+        app.conf.CELERY_ALWAYS_EAGER = True
+        try:
+            x = chord([add.s(i, i) for i in xrange(10)], body=xsum.s())
+            r = x.apply_async()
+            self.assertEqual(r.get(), 90)
+
+        finally:
+            app.conf.CELERY_ALWAYS_EAGER = False

+ 28 - 0
celery/tests/test_app/test_loaders.py

@@ -4,6 +4,8 @@ from __future__ import with_statement
 import os
 import os
 import sys
 import sys
 
 
+from mock import patch
+
 from celery import loaders
 from celery import loaders
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.exceptions import (
 from celery.exceptions import (
@@ -13,6 +15,7 @@ from celery.exceptions import (
 from celery.loaders import base
 from celery.loaders import base
 from celery.loaders import default
 from celery.loaders import default
 from celery.loaders.app import AppLoader
 from celery.loaders.app import AppLoader
+from celery.utils.imports import NotAPackage
 
 
 from celery.tests.utils import AppCase, Case
 from celery.tests.utils import AppCase, Case
 from celery.tests.compat import catch_warnings
 from celery.tests.compat import catch_warnings
@@ -158,6 +161,31 @@ class TestDefaultLoader(Case):
         self.assertFalse(l.wanted_module_item("__FOO"))
         self.assertFalse(l.wanted_module_item("__FOO"))
         self.assertFalse(l.wanted_module_item("foo"))
         self.assertFalse(l.wanted_module_item("foo"))
 
 
+    @patch("celery.loaders.default.find_module")
+    def test_read_configuration_not_a_package(self, find_module):
+        find_module.side_effect = NotAPackage()
+        l = default.Loader()
+        with self.assertRaises(NotAPackage):
+            l.read_configuration()
+
+    @patch("celery.loaders.default.find_module")
+    def test_read_configuration_py_in_name(self, find_module):
+        prev = os.environ["CELERY_CONFIG_MODULE"]
+        os.environ["CELERY_CONFIG_MODULE"] = "celeryconfig.py"
+        try:
+            find_module.side_effect = NotAPackage()
+            l = default.Loader()
+            with self.assertRaises(NotAPackage):
+                l.read_configuration()
+        finally:
+            os.environ["CELERY_CONFIG_MODULE"] = prev
+
+    @patch("celery.loaders.default.find_module")
+    def test_read_configuration_importerror(self, find_module):
+        find_module.side_effect = ImportError()
+        l = default.Loader()
+        l.read_configuration()
+
     def test_read_configuration(self):
     def test_read_configuration(self):
         from types import ModuleType
         from types import ModuleType
 
 

+ 73 - 2
celery/tests/test_app/test_log.py

@@ -5,11 +5,13 @@ import sys
 import logging
 import logging
 from tempfile import mktemp
 from tempfile import mktemp
 
 
+from mock import patch, Mock
+
 from celery import current_app
 from celery import current_app
 from celery.app.log import Logging
 from celery.app.log import Logging
 from celery.utils.log import LoggingProxy
 from celery.utils.log import LoggingProxy
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.utils.log import get_logger
+from celery.utils.log import get_logger, ColorFormatter, logger as base_logger
 from celery.tests.utils import (
 from celery.tests.utils import (
     Case, override_stdouts, wrap_logger, get_handlers,
     Case, override_stdouts, wrap_logger, get_handlers,
 )
 )
@@ -17,6 +19,58 @@ from celery.tests.utils import (
 log = current_app.log
 log = current_app.log
 
 
 
 
+class test_ColorFormatter(Case):
+
+    @patch("celery.utils.log.safe_str")
+    @patch("logging.Formatter.formatException")
+    def test_formatException_not_string(self, fe, safe_str):
+        x = ColorFormatter("HELLO")
+        value = KeyError()
+        fe.return_value = value
+        self.assertIs(x.formatException(value), value)
+        self.assertTrue(fe.called)
+        self.assertFalse(safe_str.called)
+
+    @patch("logging.Formatter.formatException")
+    @patch("celery.utils.log.safe_str")
+    def test_formatException_string(self, safe_str, fe, value="HELLO"):
+        x = ColorFormatter(value)
+        fe.return_value = value
+        self.assertTrue(x.formatException(value))
+        self.assertTrue(safe_str.called)
+
+    @patch("celery.utils.log.safe_str")
+    def test_format_raises(self, safe_str):
+        x = ColorFormatter("HELLO")
+
+        def on_safe_str(s):
+            try:
+                raise ValueError("foo")
+            finally:
+                safe_str.side_effect = None
+        safe_str.side_effect = on_safe_str
+
+        record = Mock()
+        record.levelname = "ERROR"
+        record.msg = "HELLO"
+        record.exc_text = "error text"
+        safe_str.return_value = record
+
+        x.format(record)
+        self.assertIn("<Unrepresentable", record.msg)
+        self.assertEqual(safe_str.call_count, 2)
+
+    @patch("celery.utils.log.safe_str")
+    def test_format_raises_no_color(self, safe_str):
+        x = ColorFormatter("HELLO", False)
+        record = Mock()
+        record.levelname = "ERROR"
+        record.msg = "HELLO"
+        record.exc_text = "error text"
+        x.format(record)
+        self.assertEqual(safe_str.call_count, 1)
+
+
 class test_default_logger(Case):
 class test_default_logger(Case):
 
 
     def setUp(self):
     def setUp(self):
@@ -24,6 +78,14 @@ class test_default_logger(Case):
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         Logging._setup = False
         Logging._setup = False
 
 
+    def test_get_logger_sets_parent(self):
+        logger = get_logger("celery.test_get_logger")
+        self.assertEqual(logger.parent.name, base_logger.name)
+
+    def test_get_logger_root(self):
+        logger = get_logger(base_logger.name)
+        self.assertIs(logger.parent, logging.root)
+
     def test_setup_logging_subsystem_colorize(self):
     def test_setup_logging_subsystem_colorize(self):
         log.setup_logging_subsystem(colorize=None)
         log.setup_logging_subsystem(colorize=None)
         log.setup_logging_subsystem(colorize=True)
         log.setup_logging_subsystem(colorize=True)
@@ -57,7 +119,6 @@ class test_default_logger(Case):
         Logging._setup = False
         Logging._setup = False
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False, colorize=None)
                                    root=False, colorize=None)
-        print(logger.handlers)
         self.assertIs(get_handlers(logger)[0].stream, sys.__stderr__,
         self.assertIs(get_handlers(logger)[0].stream, sys.__stderr__,
                 "setup_logger logs to stderr without logfile argument.")
                 "setup_logger logs to stderr without logfile argument.")
 
 
@@ -112,6 +173,16 @@ class test_default_logger(Case):
             self.assertFalse(p.isatty())
             self.assertFalse(p.isatty())
             self.assertIsNone(p.fileno())
             self.assertIsNone(p.fileno())
 
 
+    def test_logging_proxy_recurse_protection(self):
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                root=False)
+        p = LoggingProxy(logger, loglevel=logging.ERROR)
+        p._thread.recurse_protection = True
+        try:
+            self.assertIsNone(p.write("FOOFO"))
+        finally:
+            p._thread.recurse_protection = False
+
 
 
 class test_task_logger(test_default_logger):
 class test_task_logger(test_default_logger):
 
 

+ 5 - 10
celery/tests/test_bin/test_celeryd.py

@@ -19,15 +19,10 @@ from celery import current_app
 from celery.apps import worker as cd
 from celery.apps import worker as cd
 from celery.bin.celeryd import WorkerCommand, main as celeryd_main
 from celery.bin.celeryd import WorkerCommand, main as celeryd_main
 from celery.exceptions import ImproperlyConfigured, SystemTerminate
 from celery.exceptions import ImproperlyConfigured, SystemTerminate
+from celery.utils.log import ensure_process_aware_logger
 
 
-from celery.tests.utils import (
-    AppCase, WhateverIO, mask_modules,
-    reset_modules, skip_unless_module,
-    create_pidlock,
-)
-
+from celery.tests.utils import AppCase, WhateverIO, create_pidlock
 
 
-from celery.utils.log import ensure_process_aware_logger
 ensure_process_aware_logger()
 ensure_process_aware_logger()
 
 
 
 
@@ -416,7 +411,7 @@ class test_signal_handlers(AppCase):
     @disable_stdouts
     @disable_stdouts
     def test_worker_int_handler_only_stop_MainProcess(self):
     def test_worker_int_handler_only_stop_MainProcess(self):
         try:
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
         except ImportError:
             raise SkipTest("only relevant for multiprocessing")
             raise SkipTest("only relevant for multiprocessing")
         process = current_process()
         process = current_process()
@@ -439,7 +434,7 @@ class test_signal_handlers(AppCase):
     @disable_stdouts
     @disable_stdouts
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
         try:
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
         except ImportError:
             raise SkipTest("only relevant for multiprocessing")
             raise SkipTest("only relevant for multiprocessing")
         process = current_process()
         process = current_process()
@@ -477,7 +472,7 @@ class test_signal_handlers(AppCase):
     @disable_stdouts
     @disable_stdouts
     def test_worker_term_handler_only_stop_MainProcess(self):
     def test_worker_term_handler_only_stop_MainProcess(self):
         try:
         try:
-            import _multiprocessing
+            import _multiprocessing  # noqa
         except ImportError:
         except ImportError:
             raise SkipTest("only relevant for multiprocessing")
             raise SkipTest("only relevant for multiprocessing")
         process = current_process()
         process = current_process()

+ 158 - 0
celery/tests/test_task/test_canvas.py

@@ -0,0 +1,158 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from mock import Mock
+
+from celery import task
+from celery.canvas import Signature, chain, group, chord, subtask
+
+from celery.tests.utils import Case
+
+SIG = Signature({"task": "TASK",
+                 "args": ("A1", ),
+                 "kwargs": {"K1": "V1"},
+                 "options": {"task_id": "TASK_ID"},
+                 "subtask_type": ""})
+
+
+@task
+def add(x, y):
+    return x + y
+
+
+@task
+def mul(x, y):
+    return x * y
+
+
+@task
+def div(x, y):
+    return x / y
+
+
+class test_Signature(Case):
+
+    def test_getitem_property_class(self):
+        self.assertTrue(Signature.task)
+        self.assertTrue(Signature.args)
+        self.assertTrue(Signature.kwargs)
+        self.assertTrue(Signature.options)
+        self.assertTrue(Signature.subtask_type)
+
+    def test_getitem_property(self):
+        self.assertEqual(SIG.task, "TASK")
+        self.assertEqual(SIG.args, ("A1", ))
+        self.assertEqual(SIG.kwargs, {"K1": "V1"})
+        self.assertEqual(SIG.options, {"task_id": "TASK_ID"})
+        self.assertEqual(SIG.subtask_type, "")
+
+    def test_replace(self):
+        x = Signature("TASK", ("A"), {})
+        self.assertTupleEqual(x.replace(args=("B", )).args, ("B", ))
+        self.assertDictEqual(x.replace(kwargs={"FOO": "BAR"}).kwargs,
+                {"FOO": "BAR"})
+        self.assertDictEqual(x.replace(options={"task_id": "123"}).options,
+                {"task_id": "123"})
+
+    def test_set(self):
+        self.assertDictEqual(Signature("TASK", x=1).set(task_id="2").options,
+                {"x": 1, "task_id": "2"})
+
+    def test_link(self):
+        x = subtask(SIG)
+        x.link(SIG)
+        x.link(SIG)
+        self.assertIn(SIG, x.options["link"])
+        self.assertEqual(len(x.options["link"]), 1)
+
+    def test_link_error(self):
+        x = subtask(SIG)
+        x.link_error(SIG)
+        x.link_error(SIG)
+        self.assertIn(SIG, x.options["link_error"])
+        self.assertEqual(len(x.options["link_error"]), 1)
+
+    def test_flatten_links(self):
+        tasks = [add.s(2, 2), mul.s(4), div.s(2)]
+        tasks[0].link(tasks[1])
+        tasks[1].link(tasks[2])
+        self.assertEqual(tasks[0].flatten_links(), tasks)
+
+    def test_OR(self):
+        x = add.s(2, 2) | mul.s(4)
+        self.assertIsInstance(x, chain)
+        y = add.s(4, 4) | div.s(2)
+        z = x | y
+        self.assertIsInstance(y, chain)
+        self.assertIsInstance(z, chain)
+        self.assertEqual(len(z.tasks), 4)
+        with self.assertRaises(TypeError):
+            x | 10
+
+    def test_INVERT(self):
+        x = add.s(2, 2)
+        x.apply_async = Mock()
+        x.apply_async.return_value = Mock()
+        x.apply_async.return_value.get = Mock()
+        x.apply_async.return_value.get.return_value = 4
+        self.assertEqual(~x, 4)
+        self.assertTrue(x.apply_async.called)
+
+
+class test_chain(Case):
+
+    def test_repr(self):
+        x = add.s(2, 2) | add.s(2)
+        self.assertEqual(repr(x), '%s(2, 2) | %s(2)' % (add.name, add.name))
+
+    def test_reverse(self):
+        x = add.s(2, 2) | add.s(2)
+        self.assertIsInstance(subtask(x), chain)
+        self.assertIsInstance(subtask(dict(x)), chain)
+
+
+class test_group(Case):
+
+    def test_repr(self):
+        x = group([add.s(2, 2), add.s(4, 4)])
+        self.assertEqual(repr(x), repr(x.tasks))
+
+    def test_reverse(self):
+        x = group([add.s(2, 2), add.s(4, 4)])
+        self.assertIsInstance(subtask(x), group)
+        self.assertIsInstance(subtask(dict(x)), group)
+
+
+class test_chord(Case):
+
+    def test_reverse(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        self.assertIsInstance(subtask(x), chord)
+        self.assertIsInstance(subtask(dict(x)), chord)
+
+    def test_clone_clones_body(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        y = x.clone()
+        self.assertIsNot(x.kwargs["body"], y.kwargs["body"])
+        y.kwargs.pop("body")
+        z = y.clone()
+        self.assertIsNone(z.kwargs.get("body"))
+
+    def test_links_to_body(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        x.link(div.s(2))
+        self.assertFalse(x.options.get("link"))
+        self.assertTrue(x.kwargs["body"].options["link"])
+
+        x.link_error(div.s(2))
+        self.assertFalse(x.options.get("link_error"))
+        self.assertTrue(x.kwargs["body"].options["link_error"])
+
+        self.assertTrue(x.tasks)
+        self.assertTrue(x.body)
+
+    def test_repr(self):
+        x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
+        self.assertTrue(repr(x))
+        x.kwargs["body"] = None
+        self.assertIn("without body", repr(x))

+ 0 - 10
celery/tests/test_task/test_task_builtins.py

@@ -1,10 +0,0 @@
-from __future__ import absolute_import
-
-from celery.task import backend_cleanup
-from celery.tests.utils import Case
-
-
-class test_backend_cleanup(Case):
-
-    def test_run(self):
-        backend_cleanup.apply()

+ 38 - 2
celery/tests/test_task/test_task_sets.py

@@ -3,7 +3,9 @@ from __future__ import with_statement
 
 
 import anyjson
 import anyjson
 
 
-from celery.app import app_or_default
+from mock import Mock
+
+from celery import current_app
 from celery.task import Task
 from celery.task import Task
 from celery.task.sets import subtask, TaskSet
 from celery.task.sets import subtask, TaskSet
 from celery.canvas import Signature
 from celery.canvas import Signature
@@ -105,7 +107,7 @@ class test_TaskSet(Case):
         self.assertEqual(len(ts), 3)
         self.assertEqual(len(ts), 3)
 
 
     def test_respects_ALWAYS_EAGER(self):
     def test_respects_ALWAYS_EAGER(self):
-        app = app_or_default()
+        app = current_app
 
 
         class MockTaskSet(TaskSet):
         class MockTaskSet(TaskSet):
             applied = 0
             applied = 0
@@ -143,6 +145,25 @@ class test_TaskSet(Case):
 
 
         ts.apply_async(publisher=Publisher())
         ts.apply_async(publisher=Publisher())
 
 
+        # setting current_task
+
+        @current_app.task
+        def xyz():
+            pass
+        from celery.app.state import _tls
+        _tls.current_task = xyz
+        try:
+            ts.apply_async(publisher=Publisher())
+        finally:
+            _tls.current_task = None
+            xyz.request.clear()
+
+        # must close publisher
+        ts._Publisher = Mock()
+        ts._Publisher.return_value = Mock()
+        ts.apply_async()
+        self.assertTrue(ts._Publisher.return_value.close.called)
+
     def test_apply(self):
     def test_apply(self):
 
 
         applied = [0]
         applied = [0]
@@ -156,3 +177,18 @@ class test_TaskSet(Case):
                         for i in (2, 4, 8)])
                         for i in (2, 4, 8)])
         ts.apply()
         ts.apply()
         self.assertEqual(applied[0], 3)
         self.assertEqual(applied[0], 3)
+
+    def test_set_app(self):
+        ts = TaskSet([])
+        ts.app = 42
+        self.assertEqual(ts._app, 42)
+
+    def test_set_tasks(self):
+        ts = TaskSet([])
+        ts.tasks = [1, 2, 3]
+        self.assertEqual(ts.data, [1, 2, 3])
+
+    def test_set_Publisher(self):
+        ts = TaskSet([])
+        ts.Publisher = 42
+        self.assertEqual(ts._Publisher, 42)

+ 58 - 0
celery/tests/test_utils/test_compat.py

@@ -0,0 +1,58 @@
+from __future__ import absolute_import
+
+
+import celery
+from celery.task.base import Task
+
+from celery.tests.utils import Case
+
+
+class test_MagicModule(Case):
+
+    def test_class_property_set_without_type(self):
+        self.assertTrue(Task.__dict__["app"].__get__(Task()))
+
+    def test_class_property_set_on_class(self):
+        self.assertIs(Task.__dict__["app"].__set__(None, None),
+                      Task.__dict__["app"])
+
+    def test_class_property_set(self):
+
+        class X(Task):
+            pass
+
+        app = celery.Celery(set_as_current=False)
+        Task.__dict__["app"].__set__(X(), app)
+        self.assertEqual(X.app, app)
+
+    def test_dir(self):
+        self.assertTrue(dir(celery.messaging))
+
+    def test_direct(self):
+        import sys
+        prev_celery = sys.modules.pop("celery", None)
+        prev_task = sys.modules.pop("celery.task", None)
+        try:
+            import celery
+            self.assertTrue(celery.task)
+        finally:
+            sys.modules["celery"] = prev_celery
+            sys.modules["celery.task"] = prev_task
+
+    def test_app_attrs(self):
+        self.assertEqual(celery.task.control.broadcast,
+                         celery.current_app.control.broadcast)
+
+    def test_decorators_task(self):
+        @celery.decorators.task
+        def _test_decorators_task():
+            pass
+
+        self.assertTrue(_test_decorators_task.accept_magic_kwargs)
+
+    def test_decorators_periodic_task(self):
+        @celery.decorators.periodic_task(run_every=3600)
+        def _test_decorators_ptask():
+            pass
+
+        self.assertTrue(_test_decorators_ptask.accept_magic_kwargs)

+ 276 - 0
celery/tests/test_utils/test_local.py

@@ -0,0 +1,276 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from celery.local import Proxy, PromiseProxy, maybe_evaluate, try_import
+
+from celery.tests.utils import Case
+
+
+class test_try_import(Case):
+
+    def test_imports(self):
+        self.assertTrue(try_import(__name__))
+
+    def test_when_default(self):
+        default = object()
+        self.assertIs(try_import("foobar.awqewqe.asdwqewq", default), default)
+
+
+class test_Proxy(Case):
+
+    def test_name(self):
+
+        def real():
+            """real function"""
+            return "REAL"
+
+        x = Proxy(lambda: real, name="xyz")
+        self.assertEqual(x.__name__, "xyz")
+
+        y = Proxy(lambda: real)
+        self.assertEqual(y.__name__, "real")
+
+        self.assertEqual(x.__doc__, "real function")
+
+        self.assertEqual(x.__class__, type(real))
+        self.assertEqual(x.__dict__, real.__dict__)
+        self.assertEqual(repr(x), repr(real))
+
+    def test_nonzero(self):
+
+        class X(object):
+
+            def __nonzero__(self):
+                return False
+
+        x = Proxy(lambda: X())
+        self.assertFalse(x)
+
+    def test_slots(self):
+
+        class X(object):
+            __slots__ = ()
+
+        x = Proxy(X)
+        with self.assertRaises(AttributeError):
+            x.__dict__
+
+    def test_unicode(self):
+
+        class X(object):
+
+            def __unicode__(self):
+                return u"UNICODE"
+
+            def __repr__(self):
+                return "REPR"
+
+        x = Proxy(lambda: X())
+        self.assertEqual(unicode(x), u"UNICODE")
+        del(X.__unicode__)
+        self.assertEqual(unicode(x), "REPR")
+
+    def test_dir(self):
+
+        class X(object):
+
+            def __dir__(self):
+                return ["a", "b", "c"]
+
+        x = Proxy(lambda: X())
+        self.assertListEqual(dir(x), ["a", "b", "c"])
+
+        class Y(object):
+
+            def __dir__(self):
+                raise RuntimeError()
+        y = Proxy(lambda: Y())
+        self.assertListEqual(dir(y), [])
+
+    def test_getsetdel_attr(self):
+
+        class X(object):
+            a = 1
+            b = 2
+            c = 3
+
+            def __dir__(self):
+                return ["a", "b", "c"]
+
+        v = X()
+
+        x = Proxy(lambda: v)
+        self.assertListEqual(x.__members__, ["a", "b", "c"])
+        self.assertEqual(x.a, 1)
+        self.assertEqual(x.b, 2)
+        self.assertEqual(x.c, 3)
+
+        setattr(x, "a", 10)
+        self.assertEqual(x.a, 10)
+
+        del(x.a)
+        self.assertEqual(x.a, 1)
+
+    def test_dictproxy(self):
+        v = {}
+        x = Proxy(lambda: v)
+        x["foo"] = 42
+        self.assertEqual(x["foo"], 42)
+        self.assertEqual(len(x), 1)
+        self.assertIn("foo", x)
+        del(x["foo"])
+        with self.assertRaises(KeyError):
+            x["foo"]
+        self.assertTrue(iter(x))
+
+    def test_listproxy(self):
+        v = []
+        x = Proxy(lambda: v)
+        x.append(1)
+        x.extend([2, 3, 4])
+        self.assertEqual(x[0], 1)
+        self.assertEqual(x[:-1], [1, 2, 3])
+        del(x[-1])
+        self.assertEqual(x[:-1], [1, 2])
+        x[0] = 10
+        self.assertEqual(x[0], 10)
+        self.assertIn(10, x)
+        self.assertEqual(len(x), 3)
+        self.assertTrue(iter(x))
+
+    def test_int(self):
+        self.assertEqual(Proxy(lambda: 10) + 1, Proxy(lambda: 11))
+        self.assertEqual(Proxy(lambda: 10) - 1, Proxy(lambda: 9))
+        self.assertEqual(Proxy(lambda: 10) * 2, Proxy(lambda: 20))
+        self.assertEqual(Proxy(lambda: 10) ** 2, Proxy(lambda: 100))
+        self.assertEqual(Proxy(lambda: 20) / 2, Proxy(lambda: 10))
+        self.assertEqual(Proxy(lambda: 20) // 2, Proxy(lambda: 10))
+        self.assertEqual(Proxy(lambda: 11) % 2, Proxy(lambda: 1))
+        self.assertEqual(Proxy(lambda: 10) << 2, Proxy(lambda: 40))
+        self.assertEqual(Proxy(lambda: 10) >> 2, Proxy(lambda: 2))
+        self.assertEqual(Proxy(lambda: 10) ^ 7, Proxy(lambda: 13))
+        self.assertEqual(Proxy(lambda: 10) | 40, Proxy(lambda: 42))
+        self.assertEqual(~Proxy(lambda: 10), Proxy(lambda: -11))
+        self.assertEqual(-Proxy(lambda: 10), Proxy(lambda: -10))
+        self.assertEqual(+Proxy(lambda: -10), Proxy(lambda: -10))
+        self.assertTrue(Proxy(lambda: 10) < Proxy(lambda: 20))
+        self.assertTrue(Proxy(lambda: 20) > Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 10) >= Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 10) <= Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 10) == Proxy(lambda: 10))
+        self.assertTrue(Proxy(lambda: 20) != Proxy(lambda: 10))
+
+        x = Proxy(lambda: 10)
+        x -= 1
+        self.assertEqual(x, 9)
+        x = Proxy(lambda: 9)
+        x += 1
+        self.assertEqual(x, 10)
+        x = Proxy(lambda: 10)
+        x *= 2
+        self.assertEqual(x, 20)
+        x = Proxy(lambda: 20)
+        x /= 2
+        self.assertEqual(x, 10)
+        x = Proxy(lambda: 10)
+        x %= 2
+        self.assertEqual(x, 0)
+        x = Proxy(lambda: 10)
+        x <<= 3
+        self.assertEqual(x, 80)
+        x = Proxy(lambda: 80)
+        x >>= 4
+        self.assertEqual(x, 5)
+        x = Proxy(lambda: 5)
+        x ^= 1
+        self.assertEqual(x, 4)
+        x = Proxy(lambda: 4)
+        x **= 4
+        self.assertEqual(x, 256)
+        x = Proxy(lambda: 256)
+        x //= 2
+        self.assertEqual(x, 128)
+        x = Proxy(lambda: 128)
+        x |= 2
+        self.assertEqual(x, 130)
+        x = Proxy(lambda: 130)
+        x &= 10
+        self.assertEqual(x, 2)
+
+        x = Proxy(lambda: 10)
+        self.assertEqual(type(x.__float__()), float)
+        self.assertEqual(type(x.__int__()), int)
+        self.assertEqual(type(x.__long__()), long)
+        self.assertTrue(hex(x))
+        self.assertTrue(oct(x))
+
+    def test_hash(self):
+
+        class X(object):
+
+            def __hash__(self):
+                return 1234
+
+        self.assertEqual(hash(Proxy(lambda: X())), 1234)
+
+    def test_call(self):
+
+        class X(object):
+
+            def __call__(self):
+                return 1234
+
+        self.assertEqual(Proxy(lambda: X())(), 1234)
+
+    def test_context(self):
+
+        class X(object):
+            entered = exited = False
+
+            def __enter__(self):
+                self.entered = True
+                return 1234
+
+            def __exit__(self, *exc_info):
+                self.exited = True
+
+        v = X()
+        x = Proxy(lambda: v)
+        with x as val:
+            self.assertEqual(val, 1234)
+        self.assertTrue(x.entered)
+        self.assertTrue(x.exited)
+
+    def test_reduce(self):
+
+        class X(object):
+
+            def __reduce__(self):
+                return 123
+
+        x = Proxy(lambda: X())
+        self.assertEqual(x.__reduce__(), 123)
+
+
+class test_PromiseProxy(Case):
+
+    def test_only_evaluated_once(self):
+
+        class X(object):
+            attr = 123
+            evals = 0
+
+            def __init__(self):
+                self.__class__.evals += 1
+
+        p = PromiseProxy(X)
+        self.assertEqual(p.attr, 123)
+        self.assertEqual(p.attr, 123)
+        self.assertEqual(X.evals, 1)
+
+    def test_maybe_evaluate(self):
+        x = PromiseProxy(lambda: 30)
+        self.assertEqual(maybe_evaluate(x), 30)
+        self.assertEqual(maybe_evaluate(x), 30)
+
+        self.assertEqual(maybe_evaluate(30), 30)

+ 0 - 4
celery/tests/test_utils/test_timer2.py

@@ -48,14 +48,10 @@ class test_Schedule(Case):
 
 
         timer2.mktime = _overflow
         timer2.mktime = _overflow
         try:
         try:
-            print("+S1")
             s.enter(timer2.Entry(lambda: None, (), {}),
             s.enter(timer2.Entry(lambda: None, (), {}),
                     eta=datetime.now())
                     eta=datetime.now())
-            print("-S1")
-            print("+S2")
             s.enter(timer2.Entry(lambda: None, (), {}),
             s.enter(timer2.Entry(lambda: None, (), {}),
                     eta=None)
                     eta=None)
-            print("-S2")
             s.on_error = None
             s.on_error = None
             with self.assertRaises(OverflowError):
             with self.assertRaises(OverflowError):
                 s.enter(timer2.Entry(lambda: None, (), {}),
                 s.enter(timer2.Entry(lambda: None, (), {}),

+ 36 - 1
celery/tests/test_utils/test_utils_imports.py

@@ -1,6 +1,13 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from celery.utils.imports import qualname, symbol_by_name
+from mock import Mock, patch
+
+from celery.utils.imports import (
+    qualname,
+    symbol_by_name,
+    reload_from_cwd,
+    module_file,
+)
 
 
 from celery.tests.utils import Case
 from celery.tests.utils import Case
 
 
@@ -15,3 +22,31 @@ class test_import_utils(Case):
     def test_symbol_by_name__instance_returns_instance(self):
     def test_symbol_by_name__instance_returns_instance(self):
         instance = object()
         instance = object()
         self.assertIs(symbol_by_name(instance), instance)
         self.assertIs(symbol_by_name(instance), instance)
+
+    def test_symbol_by_name_returns_default(self):
+        default = object()
+        self.assertIs(symbol_by_name("xyz.ryx.qedoa.weq:foz",
+                        default=default), default)
+
+    def test_symbol_by_name_package(self):
+        from celery.worker import WorkController
+        self.assertIs(symbol_by_name(".worker:WorkController",
+                    package="celery"), WorkController)
+
+    @patch("celery.utils.imports.reload")
+    def test_reload_from_cwd(self, reload):
+        reload_from_cwd("foo")
+        self.assertTrue(reload.called)
+
+    def test_reload_from_cwd_custom_reloader(self):
+        reload = Mock()
+        reload_from_cwd("foo", reload)
+        self.assertTrue(reload.called)
+
+    def test_module_file(self):
+        m1 = Mock()
+        m1.__file__ = "/opt/foo/xyz.pyc"
+        self.assertEqual(module_file(m1), "/opt/foo/xyz.py")
+        m2 = Mock()
+        m2.__file__ = "/opt/foo/xyz.py"
+        self.assertEqual(module_file(m1), "/opt/foo/xyz.py")

+ 0 - 1
celery/tests/utils.py

@@ -434,7 +434,6 @@ platform_pyimp = partial(
 )
 )
 
 
 
 
-
 @contextmanager
 @contextmanager
 def sys_platform(value):
 def sys_platform(value):
     prev, sys.platform = sys.platform, value
     prev, sys.platform = sys.platform, value

+ 1 - 3
celery/utils/__init__.py

@@ -22,12 +22,9 @@ from functools import partial, wraps
 from inspect import getargspec
 from inspect import getargspec
 from pprint import pprint
 from pprint import pprint
 
 
-from billiard.util import register_after_fork
-
 from celery.exceptions import CPendingDeprecationWarning, CDeprecationWarning
 from celery.exceptions import CPendingDeprecationWarning, CDeprecationWarning
 from .compat import StringIO
 from .compat import StringIO
 
 
-from .imports import symbol_by_name, qualname
 from .functional import noop
 from .functional import noop
 
 
 PENDING_DEPRECATION_FMT = """
 PENDING_DEPRECATION_FMT = """
@@ -61,6 +58,7 @@ def deprecated(description=None, deprecation=None, removal=None,
 
 
         @wraps(fun)
         @wraps(fun)
         def __inner(*args, **kwargs):
         def __inner(*args, **kwargs):
+            from .imports import qualname
             warn_deprecated(description=description or qualname(fun),
             warn_deprecated(description=description or qualname(fun),
                             deprecation=deprecation,
                             deprecation=deprecation,
                             removal=removal,
                             removal=removal,

+ 2 - 2
celery/utils/imports.py

@@ -15,7 +15,7 @@ class NotAPackage(Exception):
     pass
     pass
 
 
 
 
-if sys.version_info >= (3, 3):
+if sys.version_info >= (3, 3):  # pragma: no cover
 
 
     def qualname(obj):
     def qualname(obj):
         return obj.__qualname__
         return obj.__qualname__
@@ -109,7 +109,7 @@ def cwd_in_path():
         finally:
         finally:
             try:
             try:
                 sys.path.remove(cwd)
                 sys.path.remove(cwd)
-            except ValueError:
+            except ValueError:  # pragma: no cover
                 pass
                 pass
 
 
 
 

+ 18 - 25
celery/utils/log.py

@@ -21,14 +21,14 @@ is_py3k = sys.version_info[0] == 3
 # Every logger in the celery package inherits from the "celery"
 # Every logger in the celery package inherits from the "celery"
 # logger, and every task logger inherits from the "celery.task"
 # logger, and every task logger inherits from the "celery.task"
 # logger.
 # logger.
-logger = _get_logger("celery")
+base_logger = logger = _get_logger("celery")
 mp_logger = _get_logger("multiprocessing")
 mp_logger = _get_logger("multiprocessing")
 
 
 
 
 def get_logger(name):
 def get_logger(name):
     l = _get_logger(name)
     l = _get_logger(name)
-    if l.parent is logging.root and l is not logger:
-        l.parent = logger
+    if logging.root not in (l, l.parent) and l is not base_logger:
+        l.parent = base_logger
     return l
     return l
 task_logger = get_logger("celery.task")
 task_logger = get_logger("celery.task")
 
 
@@ -163,34 +163,27 @@ class LoggingProxy(object):
         return None
         return None
 
 
 
 
-def _patch_logger_class():
+def ensure_process_aware_logger():
     """Make sure process name is recorded when loggers are used."""
     """Make sure process name is recorded when loggers are used."""
-    logging._acquireLock()
-    try:
-        OldLoggerClass = logging.getLoggerClass()
-        if not getattr(OldLoggerClass, '_process_aware', False):
-
-            class ProcessAwareLogger(OldLoggerClass):
+    global _process_aware
+    if not _process_aware:
+        logging._acquireLock()
+        try:
+            _process_aware = True
+            Logger = logging.getLoggerClass()
+            if getattr(Logger, '_process_aware', False):  # pragma: no cover
+                return
+
+            class ProcessAwareLogger(Logger):
                 _process_aware = True
                 _process_aware = True
 
 
                 def makeRecord(self, *args, **kwds):
                 def makeRecord(self, *args, **kwds):
-                    record = OldLoggerClass.makeRecord(self, *args, **kwds)
-                    if current_process:
-                        record.processName = current_process()._name
-                    else:
-                        record.processName = ""
+                    record = Logger.makeRecord(self, *args, **kwds)
+                    record.processName = current_process()._name
                     return record
                     return record
             logging.setLoggerClass(ProcessAwareLogger)
             logging.setLoggerClass(ProcessAwareLogger)
-    finally:
-        logging._releaseLock()
-
-
-def ensure_process_aware_logger():
-    global _process_aware
-
-    if not _process_aware:
-        _patch_logger_class()
-        _process_aware = True
+        finally:
+            logging._releaseLock()
 
 
 
 
 def get_multiprocessing_logger():
 def get_multiprocessing_logger():

+ 0 - 6
setup.cfg

@@ -9,17 +9,11 @@ cover3-exclude = celery
                  celery.bin.celeryd_detach
                  celery.bin.celeryd_detach
                  celery.bin.celeryctl
                  celery.bin.celeryctl
                  celery.bin.camqadm
                  celery.bin.camqadm
-                 celery.execute
-                 celery.local
                  celery.platforms
                  celery.platforms
-                 celery.utils.encoding
-                 celery.utils.patch
                  celery.utils.compat
                  celery.utils.compat
                  celery.utils.mail
                  celery.utils.mail
-                 celery.utils.functional
                  celery.utils.dispatch*
                  celery.utils.dispatch*
                  celery.utils.term
                  celery.utils.term
-                 celery.messaging
                  celery.db.a805d4bd
                  celery.db.a805d4bd
                  celery.db.dfd042c7
                  celery.db.dfd042c7
                  celery.contrib*
                  celery.contrib*