Forráskód Böngészése

Huge test refactoring so that the unit tests no longer uses the current app.

- During test runs the current_app and default_app is set to an object
  that will crash if any attribute is accessed.

    For tests that require the current_app to exist (like pickle tests and the
    tests in compat_modules) there is a now decorator that can be used to allow
    it:

        from celery.tests.case import AppCase, depends_on_current_app

        class test_Example(AppCase):

            @depends_on_current_app
            def test_pickle_result(self):
                pickle.loads(pickle.dumps(self.app.AsyncResult))

- There is no longer any default configuration set when the test suite runs.

    The test suite used to setup the Celery global environment by
    setting the CELERY_CONFIG_MODULE environment variable to the module
    ``celery.tests.config``.  This will no longer happen, so tests can no
    longer use the global default_app.   See below.

- AppCase.app is now created and destroyed for every test function.

- Unit tests should now use a custom app constructor: AppCase.Celery

    This will return an app object that is configured to be used for
    testing.  E.g. the broker is set to memory:// and the backend
    is set to cache+memory://, and any other configuration that
    makes it suitable for testing::

        from celery.tests.case import AppCase

        class test_Example(AppCase):

            def test_foo(self):
                with self.Celery() as app:
                    ...

    It's rarely necessary to create a new app though, as the default app
    set up in AppCase (`self.app`) is usually sufficient.  Before you would
    have to create new apps if you wanted to make changes to the app
    without causing side effects for later tests, but now the `self.app`
    is destroyed and recreated for every test.

    Note that this app will have the ``set_as_current`` argument set to False.
Ask Solem 11 éve
szülő
commit
ab83c9ef75
74 módosított fájl, 1326 hozzáadás és 2263 törlés
  1. 15 14
      celery/tests/__init__.py
  2. 1 1
      celery/tests/app/test_amqp.py
  3. 2 2
      celery/tests/app/test_annotations.py
  4. 100 147
      celery/tests/app/test_app.py
  5. 34 37
      celery/tests/app/test_beat.py
  6. 3 3
      celery/tests/app/test_builtins.py
  7. 2 2
      celery/tests/app/test_celery.py
  8. 3 9
      celery/tests/app/test_control.py
  9. 4 4
      celery/tests/app/test_defaults.py
  10. 2 2
      celery/tests/app/test_exceptions.py
  11. 27 36
      celery/tests/app/test_loaders.py
  12. 8 11
      celery/tests/app/test_log.py
  13. 34 48
      celery/tests/app/test_registry.py
  14. 62 70
      celery/tests/app/test_routes.py
  15. 2 2
      celery/tests/app/test_schedules.py
  16. 2 2
      celery/tests/app/test_utils.py
  17. 7 8
      celery/tests/backends/test_amqp.py
  18. 13 11
      celery/tests/backends/test_backends.py
  19. 10 14
      celery/tests/backends/test_base.py
  20. 7 13
      celery/tests/backends/test_cache.py
  21. 22 25
      celery/tests/backends/test_cassandra.py
  22. 46 58
      celery/tests/backends/test_couchbase.py
  23. 20 22
      celery/tests/backends/test_database.py
  24. 6 8
      celery/tests/backends/test_mongodb.py
  25. 14 21
      celery/tests/backends/test_redis.py
  26. 0 2
      celery/tests/bin/test_amqp.py
  27. 16 16
      celery/tests/bin/test_base.py
  28. 5 7
      celery/tests/bin/test_beat.py
  29. 30 25
      celery/tests/bin/test_celery.py
  30. 10 9
      celery/tests/bin/test_celeryd_detach.py
  31. 4 4
      celery/tests/bin/test_celeryevdump.py
  32. 6 6
      celery/tests/bin/test_multi.py
  33. 53 74
      celery/tests/bin/test_worker.py
  34. 128 47
      celery/tests/case.py
  35. 63 10
      celery/tests/compat_modules/test_compat.py
  36. 6 7
      celery/tests/compat_modules/test_compat_utils.py
  37. 4 3
      celery/tests/compat_modules/test_decorators.py
  38. 13 10
      celery/tests/compat_modules/test_http.py
  39. 3 2
      celery/tests/compat_modules/test_messaging.py
  40. 88 42
      celery/tests/compat_modules/test_sets.py
  41. 2 2
      celery/tests/concurrency/test_concurrency.py
  42. 4 4
      celery/tests/concurrency/test_eventlet.py
  43. 7 7
      celery/tests/concurrency/test_gevent.py
  44. 3 3
      celery/tests/concurrency/test_pool.py
  45. 2 2
      celery/tests/concurrency/test_solo.py
  46. 2 2
      celery/tests/concurrency/test_threads.py
  47. 0 44
      celery/tests/config.py
  48. 24 26
      celery/tests/contrib/test_abortable.py
  49. 1 1
      celery/tests/contrib/test_methods.py
  50. 5 5
      celery/tests/contrib/test_migrate.py
  51. 10 13
      celery/tests/events/test_events.py
  52. 4 4
      celery/tests/events/test_state.py
  53. 5 7
      celery/tests/fixups/test_django.py
  54. 13 9
      celery/tests/functional/case.py
  55. 25 37
      celery/tests/security/test_security.py
  56. 13 20
      celery/tests/tasks/test_canvas.py
  57. 13 19
      celery/tests/tasks/test_chord.py
  58. 3 3
      celery/tests/tasks/test_context.py
  59. 16 16
      celery/tests/tasks/test_result.py
  60. 49 815
      celery/tests/tasks/test_tasks.py
  61. 29 26
      celery/tests/tasks/test_trace.py
  62. 7 8
      celery/tests/utils/test_dispatcher.py
  63. 22 8
      celery/tests/utils/test_saferef.py
  64. 4 6
      celery/tests/utils/test_text.py
  65. 6 6
      celery/tests/worker/test_bootsteps.py
  66. 14 27
      celery/tests/worker/test_consumer.py
  67. 59 76
      celery/tests/worker/test_control.py
  68. 2 2
      celery/tests/worker/test_heartbeat.py
  69. 1 1
      celery/tests/worker/test_loops.py
  70. 27 30
      celery/tests/worker/test_request.py
  71. 2 2
      celery/tests/worker/test_revoke.py
  72. 8 15
      celery/tests/worker/test_state.py
  73. 9 12
      celery/tests/worker/test_strategy.py
  74. 60 171
      celery/tests/worker/test_worker.py

+ 15 - 14
celery/tests/__init__.py

@@ -14,20 +14,16 @@ except NameError:
     class WindowsError(Exception):
         pass
 
-config_module = os.environ.setdefault(
-    'CELERY_TEST_CONFIG_MODULE', 'celery.tests.config',
+os.environ.update(
+    #: warn if config module not found
+    C_WNOCONF='yes',
+    EVENTLET_NOPATCH='yes',
+    GEVENT_NOPATCH='yes',
+    KOMBU_DISABLE_LIMIT_PROTECTION='yes',
+    # virtual.QoS will not do sanity assertions when this is set.
+    KOMBU_UNITTEST='yes',
 )
 
-os.environ.setdefault('CELERY_CONFIG_MODULE', config_module)
-os.environ['CELERY_LOADER'] = 'default'
-os.environ['EVENTLET_NOPATCH'] = 'yes'
-os.environ['GEVENT_NOPATCH'] = 'yes'
-os.environ['KOMBU_DISABLE_LIMIT_PROTECTION'] = 'yes'
-os.environ['CELERY_BROKER_URL'] = 'memory://'
-
-# virtual.QoS will not do sanity assertions when this is set.
-os.environ['KOMBU_UNITTEST'] = 'yes'
-
 
 def setup():
     if os.environ.get('COVER_ALL_MODULES') or '--with-coverage3' in sys.argv:
@@ -35,6 +31,9 @@ def setup():
         with catch_warnings(record=True):
             import_all_modules()
         warnings.resetwarnings()
+    from celery.tests.case import Trap
+    from celery._state import set_default_app
+    set_default_app(Trap())
 
 
 def teardown():
@@ -81,9 +80,11 @@ def find_distribution_modules(name=__name__, file=__file__):
 
 
 def import_all_modules(name=__name__, file=__file__,
-                       skip=['celery.decorators', 'celery.contrib.batches']):
+                       skip=('celery.decorators',
+                             'celery.contrib.batches',
+                             'celery.task')):
     for module in find_distribution_modules(name, file):
-        if module not in skip:
+        if not module.startswith(skip):
             try:
                 import_module(module)
             except ImportError:

+ 1 - 1
celery/tests/app/test_amqp.py

@@ -81,7 +81,7 @@ class test_compat_TaskPublisher(AppCase):
         self.assertEqual(producer.exchange.type, 'topic')
 
     def test_compat_exchange_is_Exchange(self):
-        producer = TaskPublisher(exchange=Exchange('foo'))
+        producer = TaskPublisher(exchange=Exchange('foo'), app=self.app)
         self.assertEqual(producer.exchange.name, 'foo')
 
 

+ 2 - 2
celery/tests/app/test_annotations.py

@@ -13,12 +13,12 @@ class MyAnnotation(object):
 class AnnotationCase(AppCase):
 
     def setup(self):
-        @self.app.task()
+        @self.app.task(shared=False)
         def add(x, y):
             return x + y
         self.add = add
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mul(x, y):
             return x * y
         self.mul = mul

+ 100 - 147
celery/tests/app/test_app.py

@@ -7,7 +7,7 @@ from pickle import loads, dumps
 
 from kombu import Exchange
 
-from celery import Celery, shared_task, current_app
+from celery import shared_task, current_app
 from celery import app as _app
 from celery import _state
 from celery.app import base as _appbase
@@ -20,11 +20,13 @@ from celery.utils.serialization import pickle
 
 from celery.tests import config
 from celery.tests.case import (
-    AppCase, Case,
+    AppCase,
+    depends_on_current_app,
     mask_modules,
     platform_pyimp,
     sys_platform,
     pypy_version,
+    with_environ,
 )
 from celery.utils import uuid
 from celery.utils.mail import ErrorMail
@@ -55,42 +57,41 @@ def _get_test_config():
 test_config = _get_test_config()
 
 
-class test_module(Case):
+class test_module(AppCase):
 
     def test_default_app(self):
         self.assertEqual(_app.default_app, _state.default_app)
 
     def test_bugreport(self):
-        self.assertTrue(_app.bugreport())
+        self.assertTrue(_app.bugreport(app=self.app))
 
 
 class test_App(AppCase):
 
     def setup(self):
-        self.app.conf.update(test_config)
+        self.app.add_defaults(test_config)
 
     def test_task(self):
-        app = Celery('foozibari', set_as_current=False)
+        with self.Celery('foozibari') as app:
 
-        def fun():
-            pass
+            def fun():
+                pass
 
-        fun.__module__ = '__main__'
-        task = app.task(fun)
-        self.assertEqual(task.name, app.main + '.fun')
+            fun.__module__ = '__main__'
+            task = app.task(fun)
+            self.assertEqual(task.name, app.main + '.fun')
 
     def test_with_config_source(self):
-        app = Celery(set_as_current=False, config_source=ObjectConfig)
-        self.assertEqual(app.conf.FOO, 1)
-        self.assertEqual(app.conf.BAR, 2)
+        with self.Celery(config_source=ObjectConfig) as app:
+            self.assertEqual(app.conf.FOO, 1)
+            self.assertEqual(app.conf.BAR, 2)
 
+    @depends_on_current_app
     def test_task_windows_execv(self):
-        app = Celery(set_as_current=False)
-
         prev, _appbase._EXECV = _appbase._EXECV, True
         try:
 
-            @app.task()
+            @self.app.task(shared=False)
             def foo():
                 pass
 
@@ -101,41 +102,36 @@ class test_App(AppCase):
         assert not _appbase._EXECV
 
     def test_task_takes_no_args(self):
-        app = Celery(set_as_current=False)
-
         with self.assertRaises(TypeError):
-            @app.task(1)
+            @self.app.task(1)
             def foo():
                 pass
 
     def test_add_defaults(self):
-        app = Celery(set_as_current=False)
-
-        self.assertFalse(app.configured)
+        self.assertFalse(self.app.configured)
         _conf = {'FOO': 300}
         conf = lambda: _conf
-        app.add_defaults(conf)
-        self.assertIn(conf, app._pending_defaults)
-        self.assertFalse(app.configured)
-        self.assertEqual(app.conf.FOO, 300)
-        self.assertTrue(app.configured)
-        self.assertFalse(app._pending_defaults)
+        self.app.add_defaults(conf)
+        self.assertIn(conf, self.app._pending_defaults)
+        self.assertFalse(self.app.configured)
+        self.assertEqual(self.app.conf.FOO, 300)
+        self.assertTrue(self.app.configured)
+        self.assertFalse(self.app._pending_defaults)
 
         # defaults not pickled
-        appr = loads(dumps(app))
+        appr = loads(dumps(self.app))
         with self.assertRaises(AttributeError):
             appr.conf.FOO
 
         # add more defaults after configured
         conf2 = {'FOO': 'BAR'}
-        app.add_defaults(conf2)
-        self.assertEqual(app.conf.FOO, 'BAR')
+        self.app.add_defaults(conf2)
+        self.assertEqual(self.app.conf.FOO, 'BAR')
 
-        self.assertIn(_conf, app.conf.defaults)
-        self.assertIn(conf2, app.conf.defaults)
+        self.assertIn(_conf, self.app.conf.defaults)
+        self.assertIn(conf2, self.app.conf.defaults)
 
     def test_connection_or_acquire(self):
-
         with self.app.connection_or_acquire(block=True):
             self.assertTrue(self.app.pool._dirty)
 
@@ -173,118 +169,90 @@ class test_App(AppCase):
             self.app.autodiscover_tasks(['proj.A', 'proj.B'])
             self.assertFalse(ep.called)
 
+    @with_environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
-        prev = os.environ.get('CELERY_BROKER_URL')
-        os.environ.pop('CELERY_BROKER_URL', None)
-        try:
-            app = Celery(set_as_current=False, broker='foo://baribaz')
+        with self.Celery(broker='foo://baribaz') as app:
             self.assertEqual(app.conf.BROKER_HOST, 'foo://baribaz')
-        finally:
-            os.environ['CELERY_BROKER_URL'] = prev
 
     def test_repr(self):
         self.assertTrue(repr(self.app))
 
     def test_custom_task_registry(self):
-        app1 = Celery(set_as_current=False)
-        app2 = Celery(set_as_current=False, tasks=app1.tasks)
-        self.assertIs(app2.tasks, app1.tasks)
+        with self.Celery(tasks=self.app.tasks) as app2:
+            self.assertIs(app2.tasks, self.app.tasks)
 
     def test_include_argument(self):
-        app = Celery(set_as_current=False, include=('foo', 'bar.foo'))
-        self.assertEqual(app.conf.CELERY_IMPORTS, ('foo', 'bar.foo'))
+        with self.Celery(include=('foo', 'bar.foo')) as app:
+            self.assertEqual(app.conf.CELERY_IMPORTS, ('foo', 'bar.foo'))
 
     def test_set_as_current(self):
         current = _state._tls.current_app
         try:
-            app = Celery(set_as_current=True)
+            app = self.Celery(set_as_current=True)
             self.assertIs(_state._tls.current_app, app)
         finally:
             _state._tls.current_app = current
 
     def test_current_task(self):
-        app = Celery(set_as_current=False)
-
-        @app.task
-        def foo():
+        @self.app.task
+        def foo(shared=False):
             pass
 
         _state._task_stack.push(foo)
         try:
-            self.assertEqual(app.current_task.name, foo.name)
+            self.assertEqual(self.app.current_task.name, foo.name)
         finally:
             _state._task_stack.pop()
 
     def test_task_not_shared(self):
         with patch('celery.app.base.shared_task') as sh:
-            app = Celery(set_as_current=False)
-
-            @app.task(shared=False)
+            @self.app.task(shared=False)
             def foo():
                 pass
             self.assertFalse(sh.called)
 
     def test_task_compat_with_filter(self):
-        app = Celery(set_as_current=False, accept_magic_kwargs=True)
-        check = Mock()
+        with self.Celery(accept_magic_kwargs=True) as app:
+            check = Mock()
 
-        def filter(task):
-            check(task)
-            return task
+            def filter(task):
+                check(task)
+                return task
 
-        @app.task(filter=filter)
-        def foo():
-            pass
-        check.assert_called_with(foo)
+            @app.task(filter=filter, shared=False)
+            def foo():
+                pass
+            check.assert_called_with(foo)
 
     def test_task_with_filter(self):
-        app = Celery(set_as_current=False, accept_magic_kwargs=False)
-        check = Mock()
+        with self.Celery(accept_magic_kwargs=False) as app:
+            check = Mock()
 
-        def filter(task):
-            check(task)
-            return task
+            def filter(task):
+                check(task)
+                return task
 
-        assert not _appbase._EXECV
+            assert not _appbase._EXECV
 
-        @app.task(filter=filter)
-        def foo():
-            pass
-        check.assert_called_with(foo)
+            @app.task(filter=filter, shared=False)
+            def foo():
+                pass
+            check.assert_called_with(foo)
 
     def test_task_sets_main_name_MP_MAIN_FILE(self):
         from celery import utils as _utils
         _utils.MP_MAIN_FILE = __file__
         try:
-            app = Celery('xuzzy', set_as_current=False)
+            with self.Celery('xuzzy') as app:
 
-            @app.task
-            def foo():
-                pass
+                @app.task
+                def foo():
+                    pass
 
-            self.assertEqual(foo.name, 'xuzzy.foo')
+                self.assertEqual(foo.name, 'xuzzy.foo')
         finally:
             _utils.MP_MAIN_FILE = None
 
-    def test_base_task_inherits_magic_kwargs_from_app(self):
-        from celery.task import Task as OldTask
-
-        class timkX(OldTask):
-            abstract = True
-
-        app = Celery(set_as_current=False, accept_magic_kwargs=True)
-        timkX.bind(app)
-        # see #918
-        self.assertFalse(timkX.accept_magic_kwargs)
-
-        from celery import Task as NewTask
-
-        class timkY(NewTask):
-            abstract = True
-
-        timkY.bind(app)
-        self.assertFalse(timkY.accept_magic_kwargs)
-
     def test_annotate_decorator(self):
         from celery.app.task import Task
 
@@ -303,12 +271,11 @@ class test_App(AppCase):
                 return fun(*args, **kwargs)
             return _inner
 
-        app = Celery(set_as_current=False)
-        app.conf.CELERY_ANNOTATIONS = {
+        self.app.conf.CELERY_ANNOTATIONS = {
             adX.name: {'@__call__': deco}
         }
-        adX.bind(app)
-        self.assertIs(adX.app, app)
+        adX.bind(self.app)
+        self.assertIs(adX.app, self.app)
 
         i = adX()
         i(2, 4, x=3)
@@ -318,9 +285,7 @@ class test_App(AppCase):
         i.annotate()
 
     def test_apply_async_has__self__(self):
-        app = Celery(set_as_current=False)
-
-        @app.task(__self__='hello')
+        @self.app.task(__self__='hello', shared=False)
         def aawsX():
             pass
 
@@ -330,25 +295,22 @@ class test_App(AppCase):
             self.assertEqual(args, ('hello', 4, 5))
 
     def test_apply_async__connection_arg(self):
-        app = Celery(set_as_current=False)
-
-        @app.task()
+        @self.app.task(shared=False)
         def aacaX():
             pass
 
-        connection = app.connection('asd://')
+        connection = self.app.connection('asd://')
         with self.assertRaises(KeyError):
             aacaX.apply_async(connection=connection)
 
     def test_apply_async_adds_children(self):
         from celery._state import _task_stack
-        app = Celery(set_as_current=False)
 
-        @app.task()
+        @self.app.task(shared=False)
         def a3cX1(self):
             pass
 
-        @app.task()
+        @self.app.task(shared=False)
         def a3cX2(self):
             pass
 
@@ -363,11 +325,6 @@ class test_App(AppCase):
         finally:
             _task_stack.pop()
 
-    def test_TaskSet(self):
-        ts = self.app.TaskSet()
-        self.assertListEqual(ts.tasks, [])
-        self.assertIs(ts.app, self.app)
-
     def test_pickle_app(self):
         changes = dict(THE_FOO_BAR='bars',
                        THE_MII_MAR='jars')
@@ -461,19 +418,21 @@ class test_App(AppCase):
         x = self.app.Worker
         self.assertIs(x.app, self.app)
 
+    @depends_on_current_app
     def test_AsyncResult(self):
         x = self.app.AsyncResult('1')
         self.assertIs(x.app, self.app)
         r = loads(dumps(x))
         # not set as current, so ends up as default app after reduce
-        self.assertIs(r.app, _state.default_app)
+        self.assertIs(r.app, current_app._get_current_object())
 
     def test_get_active_apps(self):
         self.assertTrue(list(_state._get_active_apps()))
 
-        app1 = Celery(set_as_current=False)
+        app1 = self.Celery()
         appid = id(app1)
         self.assertIn(app1, _state._get_active_apps())
+        app1.close()
         del(app1)
 
         # weakref removed from list when app goes out of scope.
@@ -607,7 +566,7 @@ class test_App(AppCase):
         self.assertFalse(task.app.mail_admins.called)
 
 
-class test_defaults(Case):
+class test_defaults(AppCase):
 
     def test_str_to_bool(self):
         for s in ('false', 'no', '0'):
@@ -618,7 +577,7 @@ class test_defaults(Case):
             defaults.strtobool('unsure')
 
 
-class test_debugging_utils(Case):
+class test_debugging_utils(AppCase):
 
     def test_enable_disable_trace(self):
         try:
@@ -630,7 +589,7 @@ class test_debugging_utils(Case):
             _app.disable_trace()
 
 
-class test_pyimplementation(Case):
+class test_pyimplementation(AppCase):
 
     def test_platform_python_implementation(self):
         with platform_pyimp(lambda: 'Xython'):
@@ -656,36 +615,30 @@ class test_pyimplementation(Case):
                     self.assertEqual('CPython', pyimplementation())
 
 
-class test_shared_task(Case):
-
-    def setUp(self):
-        self._restore_app = current_app._get_current_object()
-
-    def tearDown(self):
-        self._restore_app.set_current()
+class test_shared_task(AppCase):
 
     def test_registers_to_all_apps(self):
-        xproj = Celery('xproj')
-        xproj.finalize()
+        with self.Celery('xproj', set_as_current=True) as xproj:
+            xproj.finalize()
 
-        @shared_task
-        def foo():
-            return 42
+            @shared_task
+            def foo():
+                return 42
 
-        @shared_task()
-        def bar():
-            return 84
+            @shared_task()
+            def bar():
+                return 84
 
-        self.assertIs(foo.app, xproj)
-        self.assertIs(bar.app, xproj)
-        self.assertTrue(foo._get_current_object())
+            self.assertIs(foo.app, xproj)
+            self.assertIs(bar.app, xproj)
+            self.assertTrue(foo._get_current_object())
 
-        yproj = Celery('yproj')
-        self.assertIs(foo.app, yproj)
-        self.assertIs(bar.app, yproj)
+            with self.Celery('yproj', set_as_current=True) as yproj:
+                self.assertIs(foo.app, yproj)
+                self.assertIs(bar.app, yproj)
 
-        @shared_task()
-        def baz():
-            return 168
+                @shared_task()
+                def baz():
+                    return 168
 
-        self.assertIs(baz.app, yproj)
+                self.assertIs(baz.app, yproj)

+ 34 - 37
celery/tests/app/test_beat.py

@@ -8,12 +8,10 @@ from nose import SkipTest
 from pickle import dumps, loads
 
 from celery import beat
-from celery import task
 from celery.five import keys, string_t
-from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.utils import uuid
-from celery.tests.case import AppCase, patch_settings
+from celery.tests.case import AppCase
 
 
 class Object(object):
@@ -49,10 +47,13 @@ class test_ScheduleEntry(AppCase):
     Entry = beat.ScheduleEntry
 
     def create_entry(self, **kwargs):
-        entry = dict(name='celery.unittest.add',
-                     schedule=schedule(timedelta(seconds=10)),
-                     args=(2, 2),
-                     options={'routing_key': 'cpu'})
+        entry = dict(
+            name='celery.unittest.add',
+            schedule=timedelta(seconds=10),
+            args=(2, 2),
+            options={'routing_key': 'cpu'},
+            app=self.app,
+        )
         return self.Entry(**dict(entry, **kwargs))
 
     def test_next(self):
@@ -68,6 +69,8 @@ class test_ScheduleEntry(AppCase):
 
     def test_is_due(self):
         entry = self.create_entry(schedule=timedelta(seconds=10))
+        self.assertIs(entry.app, self.app)
+        self.assertIs(entry.schedule.app, self.app)
         due1, next_time_to_run1 = entry.is_due()
         self.assertFalse(due1)
         self.assertGreater(next_time_to_run1, 9)
@@ -111,7 +114,7 @@ class mScheduler(beat.Scheduler):
                           'args': args,
                           'kwargs': kwargs,
                           'options': options})
-        return AsyncResult(uuid(), app=self.app)
+        return self.app.AsyncResult(uuid())
 
 
 class mSchedulerSchedulingError(mScheduler):
@@ -151,19 +154,19 @@ class test_Scheduler(AppCase):
 
     def test_apply_async_uses_registered_task_instances(self):
 
-        @self.app.task
+        @self.app.task(shared=False)
         def foo():
             pass
         foo.apply_async = Mock(name='foo.apply_async')
         assert foo.name in foo._get_app().tasks
 
         scheduler = mScheduler(app=self.app)
-        scheduler.apply_async(scheduler.Entry(task=foo.name))
+        scheduler.apply_async(scheduler.Entry(task=foo.name, app=self.app))
         self.assertTrue(foo.apply_async.called)
 
     def test_apply_async_should_not_sync(self):
 
-        @task()
+        @self.app.task(shared=False)
         def not_sync():
             pass
         not_sync.apply_async = Mock()
@@ -172,12 +175,12 @@ class test_Scheduler(AppCase):
         s._do_sync = Mock()
         s.should_sync = Mock()
         s.should_sync.return_value = True
-        s.apply_async(s.Entry(task=not_sync.name))
+        s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         s._do_sync.assert_called_with()
 
         s._do_sync = Mock()
         s.should_sync.return_value = False
-        s.apply_async(s.Entry(task=not_sync.name))
+        s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         self.assertFalse(s._do_sync.called)
 
     @patch('celery.app.base.Celery.send_task')
@@ -192,7 +195,7 @@ class test_Scheduler(AppCase):
 
     def test_maybe_entry(self):
         s = mScheduler(app=self.app)
-        entry = s.Entry(name='add every', task='tasks.add')
+        entry = s.Entry(name='add every', task='tasks.add', app=self.app)
         self.assertIs(s._maybe_entry(entry.name, entry), entry)
         self.assertTrue(s._maybe_entry('add every', {
             'task': 'tasks.add',
@@ -213,29 +216,23 @@ class test_Scheduler(AppCase):
         callback(KeyError(), 5)
 
     def test_install_default_entries(self):
-        with patch_settings(self.app,
-                            CELERY_TASK_RESULT_EXPIRES=None,
-                            CELERYBEAT_SCHEDULE={}):
-            s = mScheduler(app=self.app)
-            s.install_default_entries({})
-            self.assertNotIn('celery.backend_cleanup', s.data)
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = None
+        self.app.conf.CELERYBEAT_SCHEDULE = {}
+        s = mScheduler(app=self.app)
+        s.install_default_entries({})
+        self.assertNotIn('celery.backend_cleanup', s.data)
         self.app.backend.supports_autoexpire = False
-        with patch_settings(self.app,
-                            CELERY_TASK_RESULT_EXPIRES=30,
-                            CELERYBEAT_SCHEDULE={}):
-            s = mScheduler(app=self.app)
-            s.install_default_entries({})
-            self.assertIn('celery.backend_cleanup', s.data)
+
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = 30
+        s = mScheduler(app=self.app)
+        s.install_default_entries({})
+        self.assertIn('celery.backend_cleanup', s.data)
+
         self.app.backend.supports_autoexpire = True
-        try:
-            with patch_settings(self.app,
-                                CELERY_TASK_RESULT_EXPIRES=31,
-                                CELERYBEAT_SCHEDULE={}):
-                s = mScheduler(app=self.app)
-                s.install_default_entries({})
-                self.assertNotIn('celery.backend_cleanup', s.data)
-        finally:
-            self.app.backend.supports_autoexpire = False
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = 31
+        s = mScheduler(app=self.app)
+        s.install_default_entries({})
+        self.assertNotIn('celery.backend_cleanup', s.data)
 
     def test_due_tick(self):
         scheduler = mScheduler(app=self.app)
@@ -490,7 +487,7 @@ class test_EmbeddedService(AppCase):
 class test_schedule(AppCase):
 
     def test_maybe_make_aware(self):
-        x = schedule(10)
+        x = schedule(10, app=self.app)
         x.utc_enabled = True
         d = x.maybe_make_aware(datetime.utcnow())
         self.assertTrue(d.tzinfo)
@@ -499,7 +496,7 @@ class test_schedule(AppCase):
         self.assertIsNone(d2.tzinfo)
 
     def test_to_local(self):
-        x = schedule(10)
+        x = schedule(10, app=self.app)
         x.utc_enabled = True
         d = x.to_local(datetime.utcnow())
         self.assertIsNone(d.tzinfo)

+ 3 - 3
celery/tests/app/test_builtins.py

@@ -38,7 +38,7 @@ class test_map(BuiltinsCase):
 
     def test_run(self):
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def map_mul(x):
             return x[0] * x[1]
 
@@ -52,7 +52,7 @@ class test_starmap(BuiltinsCase):
 
     def test_run(self):
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def smap_mul(x, y):
             return x * y
 
@@ -67,7 +67,7 @@ class test_chunks(BuiltinsCase):
     @patch('celery.canvas.chunks.apply_chunks')
     def test_run(self, apply_chunks):
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def chunks_mul(l):
             return l
 

+ 2 - 2
celery/tests/app/test_celery.py

@@ -1,10 +1,10 @@
 from __future__ import absolute_import
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 import celery
 
 
-class test_celery_package(Case):
+class test_celery_package(AppCase):
 
     def test_version(self):
         self.assertTrue(celery.VERSION)

+ 3 - 9
celery/tests/app/test_control.py

@@ -8,7 +8,7 @@ from kombu.pidbox import Mailbox
 
 from celery.app import control
 from celery.utils import uuid
-from celery.tests.case import AppCase, Case
+from celery.tests.case import AppCase
 
 
 class MockMailbox(Mailbox):
@@ -40,7 +40,7 @@ def with_mock_broadcast(fun):
     return _resets
 
 
-class test_flatten_reply(Case):
+class test_flatten_reply(AppCase):
 
     def test_flatten_reply(self):
         reply = [
@@ -65,9 +65,6 @@ class test_inspect(AppCase):
         self.prev, self.app.control = self.app.control, self.c
         self.i = self.c.inspect()
 
-    def tearDown(self):
-        self.app.control = self.prev
-
     def test_prepare_reply(self):
         self.assertDictEqual(self.i._prepare([{'w1': {'ok': 1}},
                                               {'w2': {'ok': 1}}]),
@@ -159,14 +156,11 @@ class test_Broadcast(AppCase):
         self.control = Control(app=self.app)
         self.app.control = self.control
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mytask():
             pass
         self.mytask = mytask
 
-    def tearDown(self):
-        del(self.app.control)
-
     def test_purge(self):
         self.control.purge()
 

+ 4 - 4
celery/tests/app/test_defaults.py

@@ -7,15 +7,15 @@ from mock import Mock, patch
 
 from celery.app.defaults import NAMESPACES
 
-from celery.tests.case import Case, pypy_version, sys_platform
+from celery.tests.case import AppCase, pypy_version, sys_platform
 
 
-class test_defaults(Case):
+class test_defaults(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self._prev = sys.modules.pop('celery.app.defaults', None)
 
-    def tearDown(self):
+    def teardown(self):
         if self._prev:
             sys.modules['celery.app.defaults'] = self._prev
 

+ 2 - 2
celery/tests/app/test_exceptions.py

@@ -6,10 +6,10 @@ from datetime import datetime
 
 from celery.exceptions import RetryTaskError
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class test_RetryTaskError(Case):
+class test_RetryTaskError(AppCase):
 
     def test_when_datetime(self):
         x = RetryTaskError('foo', KeyError(), when=datetime.utcnow())

+ 27 - 36
celery/tests/app/test_loaders.py

@@ -17,7 +17,9 @@ from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 from celery.utils.mail import SendmailWarning
 
-from celery.tests.case import AppCase, Case
+from celery.tests.case import (
+    AppCase, Case, depends_on_current_app, with_environ,
+)
 
 
 class DummyLoader(base.BaseLoader):
@@ -32,15 +34,15 @@ class test_loaders(AppCase):
         self.assertEqual(loaders.get_loader_cls('default'),
                          default.Loader)
 
+    @depends_on_current_app
     def test_current_loader(self):
-        self.app.set_current()  # XXX Compat test
         with self.assertWarnsRegex(
                 CPendingDeprecationWarning,
                 r'deprecation'):
             self.assertIs(loaders.current_loader(), self.app.loader)
 
+    @depends_on_current_app
     def test_load_settings(self):
-        self.app.set_current()  # XXX Compat test
         with self.assertWarnsRegex(
                 CPendingDeprecationWarning,
                 r'deprecation'):
@@ -105,15 +107,11 @@ class test_LoaderBase(AppCase):
 
     def test_import_default_modules(self):
         modnames = lambda l: [m.__name__ for m in l]
-        prev, self.app.conf.CELERY_IMPORTS = (
-            self.app.conf.CELERY_IMPORTS, ('os', 'sys'))
-        try:
-            self.assertEqual(
-                sorted(modnames(self.loader.import_default_modules())),
-                sorted(modnames([os, sys])),
-            )
-        finally:
-            self.app.conf.CELERY_IMPORTS = prev
+        self.app.conf.CELERY_IMPORTS = ('os', 'sys')
+        self.assertEqual(
+            sorted(modnames(self.loader.import_default_modules())),
+            sorted(modnames([os, sys])),
+        )
 
     def test_import_from_cwd_custom_imp(self):
 
@@ -163,19 +161,15 @@ class test_DefaultLoader(AppCase):
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
         with self.assertRaises(NotAPackage):
-            l.read_configuration()
+            l.read_configuration(fail_silently=False)
 
     @patch('celery.loaders.base.find_module')
+    @with_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
     def test_read_configuration_py_in_name(self, find_module):
-        prev = os.environ['CELERY_CONFIG_MODULE']
-        os.environ['CELERY_CONFIG_MODULE'] = 'celeryconfig.py'
-        try:
-            find_module.side_effect = NotAPackage()
-            l = default.Loader(app=self.app)
-            with self.assertRaises(NotAPackage):
-                l.read_configuration()
-        finally:
-            os.environ['CELERY_CONFIG_MODULE'] = prev
+        find_module.side_effect = NotAPackage()
+        l = default.Loader(app=self.app)
+        with self.assertRaises(NotAPackage):
+            l.read_configuration(fail_silently=False)
 
     @patch('celery.loaders.base.find_module')
     def test_read_configuration_importerror(self, find_module):
@@ -183,9 +177,9 @@ class test_DefaultLoader(AppCase):
         find_module.side_effect = ImportError()
         l = default.Loader(app=self.app)
         with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
-            l.read_configuration()
+            l.read_configuration(fail_silently=True)
         default.C_WNOCONF = False
-        l.read_configuration()
+        l.read_configuration(fail_silently=True)
 
     def test_read_configuration(self):
         from types import ModuleType
@@ -193,17 +187,18 @@ class test_DefaultLoader(AppCase):
         class ConfigModule(ModuleType):
             pass
 
-        celeryconfig = ConfigModule('celeryconfig')
-        celeryconfig.CELERY_IMPORTS = ('os', 'sys')
         configname = os.environ.get('CELERY_CONFIG_MODULE') or 'celeryconfig'
+        celeryconfig = ConfigModule(configname)
+        celeryconfig.CELERY_IMPORTS = ('os', 'sys')
 
         prevconfig = sys.modules.get(configname)
         sys.modules[configname] = celeryconfig
         try:
             l = default.Loader(app=self.app)
-            settings = l.read_configuration()
+            l.find_module = Mock(name='find_module')
+            settings = l.read_configuration(fail_silently=False)
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
-            settings = l.read_configuration()
+            settings = l.read_configuration(fail_silently=False)
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
             l.on_worker_init()
         finally:
@@ -248,14 +243,10 @@ class test_AppLoader(AppCase):
         self.loader = AppLoader(app=self.app)
 
     def test_on_worker_init(self):
-        prev, self.app.conf.CELERY_IMPORTS = (
-            self.app.conf.CELERY_IMPORTS, ('subprocess', ))
-        try:
-            sys.modules.pop('subprocess', None)
-            self.loader.init_worker()
-            self.assertIn('subprocess', sys.modules)
-        finally:
-            self.app.conf.CELERY_IMPORTS = prev
+        self.app.conf.CELERY_IMPORTS = ('subprocess', )
+        sys.modules.pop('subprocess', None)
+        self.loader.init_worker()
+        self.assertIn('subprocess', sys.modules)
 
 
 class test_autodiscovery(Case):

+ 8 - 11
celery/tests/app/test_log.py

@@ -8,7 +8,7 @@ from mock import patch, Mock
 from nose import SkipTest
 
 from celery import signals
-from celery.app.log import Logging, TaskFormatter
+from celery.app.log import TaskFormatter
 from celery.utils.log import LoggingProxy
 from celery.utils import uuid
 from celery.utils.log import (
@@ -20,12 +20,12 @@ from celery.utils.log import (
     _patch_logger_class,
 )
 from celery.tests.case import (
-    AppCase, Case, override_stdouts, wrap_logger, get_handlers,
+    AppCase, override_stdouts, wrap_logger, get_handlers,
     restore_logging,
 )
 
 
-class test_TaskFormatter(Case):
+class test_TaskFormatter(AppCase):
 
     def test_no_task(self):
         class Record(object):
@@ -43,7 +43,7 @@ class test_TaskFormatter(Case):
         self.assertEqual(record.task_id, '???')
 
 
-class test_ColorFormatter(Case):
+class test_ColorFormatter(AppCase):
 
     @patch('celery.utils.log.safe_str')
     @patch('logging.Formatter.formatException')
@@ -139,10 +139,7 @@ class test_default_logger(AppCase):
     def test_setup_logging_subsystem_misc2(self):
         with restore_logging():
             self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
-            try:
-                self.app.log.setup_logging_subsystem()
-            finally:
-                self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
+            self.app.log.setup_logging_subsystem()
 
     def test_get_default_logger(self):
         self.assertTrue(self.app.log.get_default_logger())
@@ -277,7 +274,7 @@ class test_task_logger(test_default_logger):
         logging.root.manager.loggerDict.pop(logger.name, None)
         self.uid = uuid()
 
-        @self.app.task
+        @self.app.task(shared=False)
         def test_task():
             pass
         self.get_logger().handlers = []
@@ -285,7 +282,7 @@ class test_task_logger(test_default_logger):
         from celery._state import _task_stack
         _task_stack.push(test_task)
 
-    def tearDown(self):
+    def teardown(self):
         from celery._state import _task_stack
         _task_stack.pop()
 
@@ -296,7 +293,7 @@ class test_task_logger(test_default_logger):
         return get_task_logger("test_task_logger")
 
 
-class test_patch_logger_cls(Case):
+class test_patch_logger_cls(AppCase):
 
     def test_patches(self):
         _patch_logger_class()

+ 34 - 48
celery/tests/app/test_registry.py

@@ -1,51 +1,40 @@
 from __future__ import absolute_import
 
-from celery import Celery
-from celery.app.registry import (
-    TaskRegistry,
-    _unpickle_task,
-    _unpickle_task_v2,
-)
-from celery.task import Task, PeriodicTask
-from celery.tests.case import AppCase, Case
+from celery.app.registry import _unpickle_task, _unpickle_task_v2
+from celery.tests.case import AppCase, depends_on_current_app
 
 
-class MockTask(Task):
-    name = 'celery.unittest.test_task'
-
-    def run(self, **kwargs):
-        return True
-
-
-class MockPeriodicTask(PeriodicTask):
-    name = 'celery.unittest.test_periodic_task'
-    run_every = 10
-
-    def run(self, **kwargs):
-        return True
+def returns():
+    return 1
 
 
 class test_unpickle_task(AppCase):
 
-    def setup(self):
-        self.app = Celery(set_as_current=True)
-
+    @depends_on_current_app
     def test_unpickle_v1(self):
         self.app.tasks['txfoo'] = 'bar'
         self.assertEqual(_unpickle_task('txfoo'), 'bar')
 
+    @depends_on_current_app
     def test_unpickle_v2(self):
         self.app.tasks['txfoo1'] = 'bar1'
         self.assertEqual(_unpickle_task_v2('txfoo1'), 'bar1')
         self.assertEqual(_unpickle_task_v2('txfoo1', module='celery'), 'bar1')
 
 
-class test_TaskRegistry(Case):
+class test_TaskRegistry(AppCase):
+
+    def setup(self):
+        self.mytask = self.app.task(name='A', shared=False)(returns)
+        self.myperiodic = self.app.task(
+            name='B', shared=False, type='periodic',
+        )(returns)
 
     def test_NotRegistered_str(self):
-        self.assertTrue(repr(TaskRegistry.NotRegistered('tasks.add')))
+        self.assertTrue(repr(self.app.tasks.NotRegistered('tasks.add')))
 
     def assertRegisterUnregisterCls(self, r, task):
+        r.unregister(task)
         with self.assertRaises(r.NotRegistered):
             r.unregister(task)
         r.register(task)
@@ -58,35 +47,32 @@ class test_TaskRegistry(Case):
         self.assertIn(task_name, r)
 
     def test_task_registry(self):
-        r = TaskRegistry()
+        r = self.app._tasks
         self.assertIsInstance(r, dict, 'TaskRegistry is mapping')
 
-        self.assertRegisterUnregisterCls(r, MockTask)
-        self.assertRegisterUnregisterCls(r, MockPeriodicTask)
+        self.assertRegisterUnregisterCls(r, self.mytask)
+        self.assertRegisterUnregisterCls(r, self.myperiodic)
 
-        r.register(MockPeriodicTask)
-        r.unregister(MockPeriodicTask.name)
-        self.assertNotIn(MockPeriodicTask, r)
-        r.register(MockPeriodicTask)
+        r.register(self.myperiodic)
+        r.unregister(self.myperiodic.name)
+        self.assertNotIn(self.myperiodic, r)
+        r.register(self.myperiodic)
 
         tasks = dict(r)
-        self.assertIsInstance(tasks.get(MockTask.name), MockTask)
-        self.assertIsInstance(tasks.get(MockPeriodicTask.name),
-                              MockPeriodicTask)
+        self.assertIs(tasks.get(self.mytask.name), self.mytask)
+        self.assertIs(tasks.get(self.myperiodic.name), self.myperiodic)
 
-        self.assertIsInstance(r[MockTask.name], MockTask)
-        self.assertIsInstance(r[MockPeriodicTask.name],
-                              MockPeriodicTask)
+        self.assertIs(r[self.mytask.name], self.mytask)
+        self.assertIs(r[self.myperiodic.name], self.myperiodic)
 
-        r.unregister(MockTask)
-        self.assertNotIn(MockTask.name, r)
-        r.unregister(MockPeriodicTask)
-        self.assertNotIn(MockPeriodicTask.name, r)
+        r.unregister(self.mytask)
+        self.assertNotIn(self.mytask.name, r)
+        r.unregister(self.myperiodic)
+        self.assertNotIn(self.myperiodic.name, r)
 
-        self.assertTrue(MockTask().run())
-        self.assertTrue(MockPeriodicTask().run())
+        self.assertTrue(self.mytask.run())
+        self.assertTrue(self.myperiodic.run())
 
     def test_compat(self):
-        r = TaskRegistry()
-        r.regular()
-        r.periodic()
+        self.assertTrue(self.app.tasks.regular())
+        self.assertTrue(self.app.tasks.periodic())

+ 62 - 70
celery/tests/app/test_routes.py

@@ -1,7 +1,5 @@
 from __future__ import absolute_import
 
-from contextlib import contextmanager
-
 from kombu import Exchange
 from kombu.utils.functional import maybe_evaluate
 
@@ -22,17 +20,9 @@ def E(app, queues):
     return expand
 
 
-@contextmanager
-def _queues(app, **queues):
-    prev_queues = app.conf.CELERY_QUEUES
-    prev_Queues = app.amqp.queues
+def set_queues(app, **queues):
     app.conf.CELERY_QUEUES = queues
     app.amqp.queues = app.amqp.Queues(queues)
-    try:
-        yield
-    finally:
-        app.conf.CELERY_QUEUES = prev_queues
-        app.amqp.queues = prev_Queues
 
 
 class RouteCase(AppCase):
@@ -54,7 +44,7 @@ class RouteCase(AppCase):
             'routing_key': self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
         }
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mytask():
             pass
         self.mytask = mytask
@@ -63,24 +53,24 @@ class RouteCase(AppCase):
 class test_MapRoute(RouteCase):
 
     def test_route_for_task_expanded_route(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            expand = E(self.app, self.app.amqp.queues)
-            route = routes.MapRoute({self.mytask.name: {'queue': 'foo'}})
-            self.assertEqual(
-                expand(route.route_for_task(self.mytask.name))['queue'].name,
-                'foo',
-            )
-            self.assertIsNone(route.route_for_task('celery.awesome'))
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        expand = E(self.app, self.app.amqp.queues)
+        route = routes.MapRoute({self.mytask.name: {'queue': 'foo'}})
+        self.assertEqual(
+            expand(route.route_for_task(self.mytask.name))['queue'].name,
+            'foo',
+        )
+        self.assertIsNone(route.route_for_task('celery.awesome'))
 
     def test_route_for_task(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            expand = E(self.app, self.app.amqp.queues)
-            route = routes.MapRoute({self.mytask.name: self.b_queue})
-            self.assertDictContainsSubset(
-                self.b_queue,
-                expand(route.route_for_task(self.mytask.name)),
-            )
-            self.assertIsNone(route.route_for_task('celery.awesome'))
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        expand = E(self.app, self.app.amqp.queues)
+        route = routes.MapRoute({self.mytask.name: self.b_queue})
+        self.assertDictContainsSubset(
+            self.b_queue,
+            expand(route.route_for_task(self.mytask.name)),
+        )
+        self.assertIsNone(route.route_for_task('celery.awesome'))
 
     def test_expand_route_not_found(self):
         expand = E(self.app, self.app.amqp.Queues(
@@ -97,54 +87,56 @@ class test_lookup_route(RouteCase):
         self.assertDictEqual(router.queues, {})
 
     def test_lookup_takes_first(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            R = routes.prepare(({self.mytask.name: {'queue': 'bar'}},
-                                {self.mytask.name: {'queue': 'foo'}}))
-            router = Router(self.app, R, self.app.amqp.queues)
-            self.assertEqual(router.route({}, self.mytask.name,
-                             args=[1, 2], kwargs={})['queue'].name, 'bar')
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        R = routes.prepare(({self.mytask.name: {'queue': 'bar'}},
+                            {self.mytask.name: {'queue': 'foo'}}))
+        router = Router(self.app, R, self.app.amqp.queues)
+        self.assertEqual(router.route({}, self.mytask.name,
+                         args=[1, 2], kwargs={})['queue'].name, 'bar')
 
     def test_expands_queue_in_options(self):
-        with _queues(self.app):
-            R = routes.prepare(())
-            router = Router(
-                self.app, R, self.app.amqp.queues, create_missing=True,
-            )
-            # apply_async forwards all arguments, even exchange=None etc,
-            # so need to make sure it's merged correctly.
-            route = router.route(
-                {'queue': 'testq',
-                 'exchange': None,
-                 'routing_key': None,
-                 'immediate': False},
-                self.mytask.name,
-                args=[1, 2], kwargs={},
-            )
-            self.assertEqual(route['queue'].name, 'testq')
-            self.assertEqual(route['queue'].exchange, Exchange('testq'))
-            self.assertEqual(route['queue'].routing_key, 'testq')
-            self.assertEqual(route['immediate'], False)
+        set_queues(self.app)
+        R = routes.prepare(())
+        router = Router(
+            self.app, R, self.app.amqp.queues, create_missing=True,
+        )
+        # apply_async forwards all arguments, even exchange=None etc,
+        # so need to make sure it's merged correctly.
+        route = router.route(
+            {'queue': 'testq',
+             'exchange': None,
+             'routing_key': None,
+             'immediate': False},
+            self.mytask.name,
+            args=[1, 2], kwargs={},
+        )
+        self.assertEqual(route['queue'].name, 'testq')
+        self.assertEqual(route['queue'].exchange, Exchange('testq'))
+        self.assertEqual(route['queue'].routing_key, 'testq')
+        self.assertEqual(route['immediate'], False)
 
     def test_expand_destination_string(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            x = Router(self.app, {}, self.app.amqp.queues)
-            dest = x.expand_destination('foo')
-            self.assertEqual(dest['queue'].name, 'foo')
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        x = Router(self.app, {}, self.app.amqp.queues)
+        dest = x.expand_destination('foo')
+        self.assertEqual(dest['queue'].name, 'foo')
 
     def test_lookup_paths_traversed(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue, **{
-                self.app.conf.CELERY_DEFAULT_QUEUE: self.d_queue}):
-            R = routes.prepare((
-                {'celery.xaza': {'queue': 'bar'}},
-                {self.mytask.name: {'queue': 'foo'}}
-            ))
-            router = Router(self.app, R, self.app.amqp.queues)
-            self.assertEqual(router.route({}, self.mytask.name,
-                             args=[1, 2], kwargs={})['queue'].name, 'foo')
-            self.assertEqual(
-                router.route({}, 'celery.poza')['queue'].name,
-                self.app.conf.CELERY_DEFAULT_QUEUE,
-            )
+        set_queues(
+            self.app, foo=self.a_queue, bar=self.b_queue,
+            **{self.app.conf.CELERY_DEFAULT_QUEUE: self.d_queue}
+        )
+        R = routes.prepare((
+            {'celery.xaza': {'queue': 'bar'}},
+            {self.mytask.name: {'queue': 'foo'}}
+        ))
+        router = Router(self.app, R, self.app.amqp.queues)
+        self.assertEqual(router.route({}, self.mytask.name,
+                         args=[1, 2], kwargs={})['queue'].name, 'foo')
+        self.assertEqual(
+            router.route({}, 'celery.poza')['queue'].name,
+            self.app.conf.CELERY_DEFAULT_QUEUE,
+        )
 
 
 class test_prepare(AppCase):

+ 2 - 2
celery/tests/app/test_schedules.py

@@ -365,8 +365,8 @@ class test_crontab_remaining_estimate(AppCase):
 
     def test_not_weekmonthdayyear(self):
         next = self.next_ocurrance(
-            crontab(minute=[5, 42], day_of_week='fri,sat',
-                    day_of_month=29, month_of_year='2-10'),
+            self.crontab(minute=[5, 42], day_of_week='fri,sat',
+                         day_of_month=29, month_of_year='2-10'),
             datetime(2010, 1, 28, 14, 30, 15),
         )
         self.assertEqual(next, datetime(2010, 5, 29, 0, 5))

+ 2 - 2
celery/tests/app/test_utils.py

@@ -4,10 +4,10 @@ from collections import Mapping, MutableMapping
 
 from celery.app.utils import Settings, bugreport
 
-from celery.tests.case import AppCase, Case, Mock
+from celery.tests.case import AppCase, Mock
 
 
-class TestSettings(Case):
+class TestSettings(AppCase):
     """
     Tests of celery.app.utils.Settings
     """

+ 7 - 8
celery/tests/backends/test_amqp.py

@@ -16,7 +16,9 @@ from celery.exceptions import TimeoutError
 from celery.five import Empty, Queue, range
 from celery.utils import uuid
 
-from celery.tests.case import AppCase, sleepdeprived, Mock
+from celery.tests.case import (
+    AppCase, Mock, depends_on_current_app, sleepdeprived,
+)
 
 
 class SomeClass(object):
@@ -43,6 +45,7 @@ class test_AMQPBackend(AppCase):
         self.assertTrue(tb2._cache.get(tid))
         self.assertTrue(tb2.get_result(tid), 42)
 
+    @depends_on_current_app
     def test_pickleable(self):
         self.assertTrue(loads(dumps(self.create_backend())))
 
@@ -322,14 +325,10 @@ class test_AMQPBackend(AppCase):
     def test_no_expires(self):
         b = self.create_backend(expires=None)
         app = self.app
-        prev = app.conf.CELERY_TASK_RESULT_EXPIRES
         app.conf.CELERY_TASK_RESULT_EXPIRES = None
-        try:
-            b = self.create_backend(expires=None)
-            with self.assertRaises(KeyError):
-                b.queue_arguments['x-expires']
-        finally:
-            app.conf.CELERY_TASK_RESULT_EXPIRES = prev
+        b = self.create_backend(expires=None)
+        with self.assertRaises(KeyError):
+            b.queue_arguments['x-expires']
 
     def test_process_cleanup(self):
         self.create_backend().process_cleanup()

+ 13 - 11
celery/tests/backends/test_backends.py

@@ -5,17 +5,19 @@ from mock import patch
 from celery import backends
 from celery.backends.amqp import AMQPBackend
 from celery.backends.cache import CacheBackend
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 class test_backends(AppCase):
 
     def test_get_backend_aliases(self):
-        expects = [('amqp', AMQPBackend),
-                   ('cache', CacheBackend)]
-        for expect_name, expect_cls in expects:
+        expects = [('amqp://', AMQPBackend),
+                   ('cache+memory://', CacheBackend)]
+
+        for url, expect_cls in expects:
+            backend, url = backends.get_backend_by_url(url, self.app.loader)
             self.assertIsInstance(
-                backends.get_backend_cls(expect_name)(app=self.app),
+                backend(app=self.app, url=url),
                 expect_cls,
             )
 
@@ -23,22 +25,22 @@ class test_backends(AppCase):
         backends.get_backend_cls.clear()
         hits = backends.get_backend_cls.hits
         misses = backends.get_backend_cls.misses
-        self.assertTrue(backends.get_backend_cls('amqp'))
+        self.assertTrue(backends.get_backend_cls('amqp', self.app.loader))
         self.assertEqual(backends.get_backend_cls.misses, misses + 1)
-        self.assertTrue(backends.get_backend_cls('amqp'))
+        self.assertTrue(backends.get_backend_cls('amqp', self.app.loader))
         self.assertEqual(backends.get_backend_cls.hits, hits + 1)
 
     def test_unknown_backend(self):
         with self.assertRaises(ImportError):
-            backends.get_backend_cls('fasodaopjeqijwqe')
+            backends.get_backend_cls('fasodaopjeqijwqe', self.app.loader)
 
+    @depends_on_current_app
     def test_default_backend(self):
-        self.app.set_current()  # XXX compat test
         self.assertEqual(backends.default_backend, self.app.backend)
 
     def test_backend_by_url(self, url='redis://localhost/1'):
         from celery.backends.redis import RedisBackend
-        backend, url_ = backends.get_backend_by_url(url)
+        backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         self.assertIs(backend, RedisBackend)
         self.assertEqual(url_, url)
 
@@ -46,4 +48,4 @@ class test_backends(AppCase):
         with patch('celery.backends.symbol_by_name') as sbn:
             sbn.side_effect = ValueError()
             with self.assertRaises(ValueError):
-                backends.get_backend_cls('xxx.xxx:foo')
+                backends.get_backend_cls('xxx.xxx:foo', self.app.loader)

+ 10 - 14
celery/tests/backends/test_base.py

@@ -9,7 +9,6 @@ from nose import SkipTest
 
 from celery.exceptions import ChordError
 from celery.five import items, range
-from celery.result import AsyncResult, GroupResult
 from celery.utils import serialization
 from celery.utils.serialization import subclass_exception
 from celery.utils.serialization import find_pickleable_exception as fnpe
@@ -24,7 +23,7 @@ from celery.backends.base import (
 )
 from celery.utils import uuid
 
-from celery.tests.case import AppCase, Case
+from celery.tests.case import AppCase
 
 
 class wrapobject(object):
@@ -66,18 +65,15 @@ class test_BaseBackend_interface(AppCase):
         self.b.on_chord_part_return(None)
 
     def test_on_chord_apply(self, unlock='celery.chord_unlock'):
-        p, self.app.tasks[unlock] = self.app.tasks.get(unlock), Mock()
-        try:
-            self.b.on_chord_apply(
-                'dakj221', 'sdokqweok',
-                result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
-            )
-            self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
-        finally:
-            self.app.tasks[unlock] = p
+        self.app.tasks[unlock] = Mock()
+        self.b.on_chord_apply(
+            'dakj221', 'sdokqweok',
+            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
+        )
+        self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
 
 
-class test_exception_pickle(Case):
+class test_exception_pickle(AppCase):
 
     def test_oldstyle(self):
         if Oldstyle is None:
@@ -224,7 +220,7 @@ class test_BaseBackend_dict(AppCase):
             self.assertTrue(args[2])
 
     def test_prepare_value_serializes_group_result(self):
-        g = GroupResult('group_id', [AsyncResult('foo')])
+        g = self.app.GroupResult('group_id', [self.app.AsyncResult('foo')])
         self.assertIsInstance(self.b.prepare_value(g), (list, tuple))
 
     def test_is_cached(self):
@@ -286,7 +282,7 @@ class test_KeyValueStoreBackend(AppCase):
     @contextmanager
     def _chord_part_context(self, b):
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def callback(result):
             pass
 

+ 7 - 13
celery/tests/backends/test_cache.py

@@ -8,12 +8,11 @@ from contextlib import contextmanager
 from kombu.utils.encoding import str_to_bytes
 from mock import Mock, patch
 
+from celery import subtask
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
 from celery.five import items, string, text_t
-from celery.result import AsyncResult
-from celery.task import subtask
 from celery.utils import uuid
 
 from celery.tests.case import AppCase, mask_modules, reset_modules
@@ -32,14 +31,9 @@ class test_CacheBackend(AppCase):
         self.tid = uuid()
 
     def test_no_backend(self):
-        prev, self.app.conf.CELERY_CACHE_BACKEND = (
-            self.app.conf.CELERY_CACHE_BACKEND, None,
-        )
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                CacheBackend(backend=None, app=self.app)
-        finally:
-            self.app.conf.CELERY_CACHE_BACKEND = prev
+        self.app.conf.CELERY_CACHE_BACKEND = None
+        with self.assertRaises(ImproperlyConfigured):
+            CacheBackend(backend=None, app=self.app)
 
     def test_mark_as_done(self):
         self.assertEqual(self.tb.get_status(self.tid), states.PENDING)
@@ -67,7 +61,7 @@ class test_CacheBackend(AppCase):
 
     def test_on_chord_apply(self):
         tb = CacheBackend(backend='memory://', app=self.app)
-        gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
+        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
         tb.on_chord_apply(gid, {}, result=res)
 
     @patch('celery.result.GroupResult.restore')
@@ -83,7 +77,7 @@ class test_CacheBackend(AppCase):
         self.app.tasks['foobarbaz'] = task
         task.request.chord = subtask(task)
 
-        gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
+        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
         task.request.group = gid
         tb.on_chord_apply(gid, {}, result=res)
 
@@ -104,7 +98,7 @@ class test_CacheBackend(AppCase):
 
     def test_forget(self):
         self.tb.mark_as_done(self.tid, {'foo': 'bar'})
-        x = AsyncResult(self.tid, backend=self.tb)
+        x = self.app.AsyncResult(self.tid, backend=self.tb)
         x.forget()
         self.assertIsNone(x.result)
 

+ 22 - 25
celery/tests/backends/test_cassandra.py

@@ -5,10 +5,9 @@ import socket
 from mock import Mock
 from pickle import loads, dumps
 
-from celery import Celery
 from celery import states
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase, mock_module
+from celery.tests.case import AppCase, mock_module, depends_on_current_app
 
 
 class Object(object):
@@ -46,6 +45,13 @@ def install_exceptions(mod):
 
 class test_CassandraBackend(AppCase):
 
+    def setup(self):
+        self.app.conf.update(
+            CASSANDRA_SERVERS=['example.com'],
+            CASSANDRA_KEYSPACE='keyspace',
+            CASSANDRA_COLUMN_FAMILY='columns',
+        )
+
     def test_init_no_pycassa(self):
         with mock_module('pycassa'):
             from celery.backends import cassandra as mod
@@ -56,13 +62,6 @@ class test_CassandraBackend(AppCase):
             finally:
                 mod.pycassa = prev
 
-    def get_app(self):
-        celery = Celery(set_as_current=False)
-        celery.conf.CASSANDRA_SERVERS = ['example.com']
-        celery.conf.CASSANDRA_KEYSPACE = 'keyspace'
-        celery.conf.CASSANDRA_COLUMN_FAMILY = 'columns'
-        return celery
-
     def test_init_with_and_without_LOCAL_QUROM(self):
         with mock_module('pycassa'):
             from celery.backends import cassandra as mod
@@ -71,23 +70,25 @@ class test_CassandraBackend(AppCase):
             cons = mod.pycassa.ConsistencyLevel = Object()
             cons.LOCAL_QUORUM = 'foo'
 
-            app = self.get_app()
-            app.conf.CASSANDRA_READ_CONSISTENCY = 'LOCAL_FOO'
-            app.conf.CASSANDRA_WRITE_CONSISTENCY = 'LOCAL_FOO'
+            self.app.conf.CASSANDRA_READ_CONSISTENCY = 'LOCAL_FOO'
+            self.app.conf.CASSANDRA_WRITE_CONSISTENCY = 'LOCAL_FOO'
 
-            mod.CassandraBackend(app=app)
+            mod.CassandraBackend(app=self.app)
             cons.LOCAL_FOO = 'bar'
-            mod.CassandraBackend(app=app)
+            mod.CassandraBackend(app=self.app)
 
             # no servers raises ImproperlyConfigured
             with self.assertRaises(ImproperlyConfigured):
-                app.conf.CASSANDRA_SERVERS = None
-                mod.CassandraBackend(app=app, keyspace='b', column_family='c')
+                self.app.conf.CASSANDRA_SERVERS = None
+                mod.CassandraBackend(
+                    app=self.app, keyspace='b', column_family='c',
+                )
 
+    @depends_on_current_app
     def test_reduce(self):
         with mock_module('pycassa'):
             from celery.backends.cassandra import CassandraBackend
-            self.assertTrue(loads(dumps(CassandraBackend(app=self.get_app()))))
+            self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
 
     def test_get_task_meta_for(self):
         with mock_module('pycassa'):
@@ -96,8 +97,7 @@ class test_CassandraBackend(AppCase):
             install_exceptions(mod.pycassa)
             mod.Thrift = Mock()
             install_exceptions(mod.Thrift)
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             Get_Column = x._get_column_family = Mock()
             get_column = Get_Column.return_value = Mock()
             get = get_column.get
@@ -155,8 +155,7 @@ class test_CassandraBackend(AppCase):
             install_exceptions(mod.pycassa)
             mod.Thrift = Mock()
             install_exceptions(mod.Thrift)
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             Get_Column = x._get_column_family = Mock()
             cf = Get_Column.return_value = Mock()
             x.detailed_mode = False
@@ -171,8 +170,7 @@ class test_CassandraBackend(AppCase):
     def test_process_cleanup(self):
         with mock_module('pycassa'):
             from celery.backends import cassandra as mod
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             x._column_family = None
             x.process_cleanup()
 
@@ -185,8 +183,7 @@ class test_CassandraBackend(AppCase):
             from celery.backends import cassandra as mod
             mod.pycassa = Mock()
             install_exceptions(mod.pycassa)
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             self.assertTrue(x._get_column_family())
             self.assertIsNotNone(x._column_family)
             self.assertIs(x._get_column_family(), x._column_family)

+ 46 - 58
celery/tests/backends/test_couchbase.py

@@ -3,7 +3,6 @@ from __future__ import absolute_import
 from mock import MagicMock, Mock, patch, sentinel
 from nose import SkipTest
 
-from celery import Celery
 from celery.backends import couchbase as module
 from celery.backends.couchbase import CouchBaseBackend
 from celery.exceptions import ImproperlyConfigured
@@ -36,19 +35,16 @@ class test_CouchBaseBackend(AppCase):
 
     def test_init_no_settings(self):
         """test init no settings"""
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = []
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = []
         with self.assertRaises(ImproperlyConfigured):
-            CouchBaseBackend(app=celery)
+            CouchBaseBackend(app=self.app)
 
     def test_init_settings_is_None(self):
         """Test init settings is None"""
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
-        CouchBaseBackend(app=celery)
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
+        CouchBaseBackend(app=self.app)
 
     def test_get_connection_connection_exists(self):
-        """Test get existing connection"""
         with patch('couchbase.connection.Connection') as mock_Connection:
             self.backend._connection = sentinel._connection
 
@@ -58,89 +54,81 @@ class test_CouchBaseBackend(AppCase):
             self.assertFalse(mock_Connection.called)
 
     def test_get(self):
-        """Test get
+        """test_get
 
         CouchBaseBackend.get should return  and take two params
         db conn to couchbase is mocked.
         TODO Should test on key not exists
 
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
-
-            x = CouchBaseBackend(app=app)
-            x._connection = Mock()
-            mocked_get = x._connection.get = Mock()
-            mocked_get.return_value.value = sentinel.retval
-            # should return None
-            self.assertEqual(x.get('1f3fab'), sentinel.retval)
-            x._connection.get.assert_called_once_with('1f3fab')
-
-    # betta
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
+        x = CouchBaseBackend(app=self.app)
+        x._connection = Mock()
+        mocked_get = x._connection.get = Mock()
+        mocked_get.return_value.value = sentinel.retval
+        # should return None
+        self.assertEqual(x.get('1f3fab'), sentinel.retval)
+        x._connection.get.assert_called_once_with('1f3fab')
+
     def test_set(self):
-        """Test set
+        """test_set
 
         CouchBaseBackend.set should return None and take two params
         db conn to couchbase is mocked.
 
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
-            x = CouchBaseBackend(app=app)
-            x._connection = MagicMock()
-            x._connection.set = MagicMock()
-            # should return None
-            self.assertIsNone(x.set(sentinel.key, sentinel.value))
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
+        x = CouchBaseBackend(app=self.app)
+        x._connection = MagicMock()
+        x._connection.set = MagicMock()
+        # should return None
+        self.assertIsNone(x.set(sentinel.key, sentinel.value))
 
     def test_delete(self):
-        """Test delete
+        """test_delete
 
         CouchBaseBackend.delete should return and take two params
         db conn to couchbase is mocked.
         TODO Should test on key not exists
 
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
-            x = CouchBaseBackend(app=app)
-            x._connection = Mock()
-            mocked_delete = x._connection.delete = Mock()
-            mocked_delete.return_value = None
-            # should return None
-            self.assertIsNone(x.delete('1f3fab'))
-            x._connection.delete.assert_called_once_with('1f3fab')
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
+        x = CouchBaseBackend(app=self.app)
+        x._connection = Mock()
+        mocked_delete = x._connection.delete = Mock()
+        mocked_delete.return_value = None
+        # should return None
+        self.assertIsNone(x.delete('1f3fab'))
+        x._connection.delete.assert_called_once_with('1f3fab')
 
     def test_config_params(self):
-        """test celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS
+        """test_config_params
 
         celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS is properly set
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {
-                'bucket': 'mycoolbucket',
-                'host': ['here.host.com', 'there.host.com'],
-                'username': 'johndoe',
-                'password': 'mysecret',
-                'port': '1234',
-            }
-            x = CouchBaseBackend(app=app)
-            self.assertEqual(x.bucket, 'mycoolbucket')
-            self.assertEqual(x.host, ['here.host.com', 'there.host.com'],)
-            self.assertEqual(x.username, 'johndoe',)
-            self.assertEqual(x.password, 'mysecret')
-            self.assertEqual(x.port, 1234)
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {
+            'bucket': 'mycoolbucket',
+            'host': ['here.host.com', 'there.host.com'],
+            'username': 'johndoe',
+            'password': 'mysecret',
+            'port': '1234',
+        }
+        x = CouchBaseBackend(app=self.app)
+        self.assertEqual(x.bucket, 'mycoolbucket')
+        self.assertEqual(x.host, ['here.host.com', 'there.host.com'],)
+        self.assertEqual(x.username, 'johndoe',)
+        self.assertEqual(x.password, 'mysecret')
+        self.assertEqual(x.port, 1234)
 
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
-        """test get backend by url"""
         from celery.backends.couchbase import CouchBaseBackend
-        backend, url_ = backends.get_backend_by_url(url)
+        backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         self.assertIs(backend, CouchBaseBackend)
         self.assertEqual(url_, url)
 
     def test_backend_params_by_url(self):
-        """test get backend params by url"""
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
-        with Celery(set_as_current=False, backend=url) as app:
+        with self.Celery(backend=url) as app:
             x = app.backend
             self.assertEqual(x.bucket, "mycoolbucket")
             self.assertEqual(x.host, "myhost")

+ 20 - 22
celery/tests/backends/test_database.py

@@ -7,11 +7,11 @@ from pickle import loads, dumps
 
 from celery import states
 from celery.exceptions import ImproperlyConfigured
-from celery.result import AsyncResult
 from celery.utils import uuid
 
 from celery.tests.case import (
     AppCase,
+    depends_on_current_app,
     mask_modules,
     skip_if_pypy,
     skip_if_jython,
@@ -39,6 +39,7 @@ class test_DatabaseBackend(AppCase):
     def setup(self):
         if DatabaseBackend is None:
             raise SkipTest('sqlalchemy not installed')
+        self.uri = 'sqlite:///test.db'
 
     def test_retry_helper(self):
         from celery.backends.database import OperationalError
@@ -61,20 +62,16 @@ class test_DatabaseBackend(AppCase):
                 _sqlalchemy_installed()
 
     def test_missing_dburi_raises_ImproperlyConfigured(self):
-        conf = self.app.conf
-        prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                DatabaseBackend(app=self.app)
-        finally:
-            conf.CELERY_RESULT_DBURI = prev
+        self.app.conf.CELERY_RESULT_DBURI = None
+        with self.assertRaises(ImproperlyConfigured):
+            DatabaseBackend(app=self.app)
 
     def test_missing_task_id_is_PENDING(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         self.assertEqual(tb.get_status('xxx-does-not-exist'), states.PENDING)
 
     def test_missing_task_meta_is_dict_with_pending(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         self.assertDictContainsSubset({
             'status': states.PENDING,
             'task_id': 'xxx-does-not-exist-at-all',
@@ -83,7 +80,7 @@ class test_DatabaseBackend(AppCase):
         }, tb.get_task_meta('xxx-does-not-exist-at-all'))
 
     def test_mark_as_done(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
         tid = uuid()
 
@@ -95,7 +92,7 @@ class test_DatabaseBackend(AppCase):
         self.assertEqual(tb.get_result(tid), 42)
 
     def test_is_pickled(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
         tid2 = uuid()
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
@@ -106,19 +103,19 @@ class test_DatabaseBackend(AppCase):
         self.assertEqual(rindb.get('bar').data, 12345)
 
     def test_mark_as_started(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tb.mark_as_started(tid)
         self.assertEqual(tb.get_status(tid), states.STARTED)
 
     def test_mark_as_revoked(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tb.mark_as_revoked(tid)
         self.assertEqual(tb.get_status(tid), states.REVOKED)
 
     def test_mark_as_retry(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         try:
             raise KeyError('foo')
@@ -131,7 +128,7 @@ class test_DatabaseBackend(AppCase):
             self.assertEqual(tb.get_traceback(tid), trace)
 
     def test_mark_as_failure(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
         tid3 = uuid()
         try:
@@ -145,24 +142,25 @@ class test_DatabaseBackend(AppCase):
             self.assertEqual(tb.get_traceback(tid3), trace)
 
     def test_forget(self):
-        tb = DatabaseBackend(backend='memory://', app=self.app)
+        tb = DatabaseBackend(self.uri, backend='memory://', app=self.app)
         tid = uuid()
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
-        x = AsyncResult(tid, backend=tb)
+        x = self.app.AsyncResult(tid, backend=tb)
         x.forget()
         self.assertIsNone(x.result)
 
     def test_process_cleanup(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tb.process_cleanup()
 
+    @depends_on_current_app
     def test_reduce(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         self.assertTrue(loads(dumps(tb)))
 
     def test_save__restore__delete_group(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
         tid = uuid()
         res = {'something': 'special'}
@@ -177,7 +175,7 @@ class test_DatabaseBackend(AppCase):
         self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
 
     def test_cleanup(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         for i in range(10):
             tb.mark_as_done(uuid(), 42)
             tb.save_group(uuid(), {'foo': 'bar'})

+ 6 - 8
celery/tests/backends/test_mongodb.py

@@ -7,12 +7,11 @@ from mock import MagicMock, Mock, patch, sentinel
 from nose import SkipTest
 from pickle import loads, dumps
 
-from celery import Celery
 from celery import states
 from celery.backends import mongodb as module
 from celery.backends.mongodb import MongoBackend, Bunch, pymongo
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 COLLECTION = 'taskmeta_celery'
 TASK_ID = str(uuid.uuid1())
@@ -58,15 +57,13 @@ class test_MongoBackend(AppCase):
             module.pymongo = prev
 
     def test_init_no_settings(self):
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_MONGODB_BACKEND_SETTINGS = []
+        self.app.conf.CELERY_MONGODB_BACKEND_SETTINGS = []
         with self.assertRaises(ImproperlyConfigured):
-            MongoBackend(app=celery)
+            MongoBackend(app=self.app)
 
     def test_init_settings_is_None(self):
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_MONGODB_BACKEND_SETTINGS = None
-        MongoBackend(app=celery)
+        self.app.conf.CELERY_MONGODB_BACKEND_SETTINGS = None
+        MongoBackend(app=self.app)
 
     def test_restore_group_no_entry(self):
         x = MongoBackend(app=self.app)
@@ -75,6 +72,7 @@ class test_MongoBackend(AppCase):
         fo.return_value = None
         self.assertIsNone(x._restore_group('1f3fab'))
 
+    @depends_on_current_app
     def test_reduce(self):
         x = MongoBackend(app=self.app)
         self.assertTrue(loads(dumps(x)))

+ 14 - 21
celery/tests/backends/test_redis.py

@@ -8,14 +8,13 @@ from pickle import loads, dumps
 
 from kombu.utils import cached_property, uuid
 
+from celery import subtask
 from celery import states
 from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
-from celery.result import AsyncResult
-from celery.task import subtask
 from celery.utils.timeutils import timedelta_seconds
 
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 class Redis(object):
@@ -85,6 +84,7 @@ class test_RedisBackend(AppCase):
 
         self.MockBackend = MockBackend
 
+    @depends_on_current_app
     def test_reduce(self):
         try:
             from celery.backends.redis import RedisBackend
@@ -104,25 +104,18 @@ class test_RedisBackend(AppCase):
         self.assertEqual(x.db, '1')
 
     def test_conf_raises_KeyError(self):
-        conf = AttributeDict({'CELERY_RESULT_SERIALIZER': 'json',
-                              'CELERY_MAX_CACHED_RESULTS': 1,
-                              'CELERY_ACCEPT_CONTENT': ['json'],
-                              'CELERY_TASK_RESULT_EXPIRES': None})
-        prev, self.app.conf = self.app.conf, conf
-        try:
-            self.MockBackend(app=self.app)
-        finally:
-            self.app.conf = prev
+        self.app.conf = AttributeDict({
+            'CELERY_RESULT_SERIALIZER': 'json',
+            'CELERY_MAX_CACHED_RESULTS': 1,
+            'CELERY_ACCEPT_CONTENT': ['json'],
+            'CELERY_TASK_RESULT_EXPIRES': None,
+        })
+        self.MockBackend(app=self.app)
 
     def test_expires_defaults_to_config(self):
-        conf = self.app.conf
-        prev = conf.CELERY_TASK_RESULT_EXPIRES
-        conf.CELERY_TASK_RESULT_EXPIRES = 10
-        try:
-            b = self.Backend(expires=None, app=self.app)
-            self.assertEqual(b.expires, 10)
-        finally:
-            conf.CELERY_TASK_RESULT_EXPIRES = prev
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = 10
+        b = self.Backend(expires=None, app=self.app)
+        self.assertEqual(b.expires, 10)
 
     def test_expires_is_int(self):
         b = self.Backend(expires=48, app=self.app)
@@ -140,7 +133,7 @@ class test_RedisBackend(AppCase):
     def test_on_chord_apply(self):
         self.Backend(app=self.app).on_chord_apply(
             'group_id', {},
-            result=[AsyncResult(x) for x in [1, 2, 3]],
+            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
         )
 
     def test_mget(self):

+ 0 - 2
celery/tests/bin/test_amqp.py

@@ -2,7 +2,6 @@ from __future__ import absolute_import
 
 from mock import Mock, patch
 
-from celery import Celery
 from celery.bin.amqp import (
     AMQPAdmin,
     AMQShell,
@@ -18,7 +17,6 @@ class test_AMQShell(AppCase):
 
     def setup(self):
         self.fh = WhateverIO()
-        self.app = Celery(broker='memory://', set_as_current=False)
         self.adm = self.create_adm()
         self.shell = AMQShell(connect=self.adm.connect, out=self.fh)
 

+ 16 - 16
celery/tests/bin/test_base.py

@@ -10,7 +10,9 @@ from celery.bin.base import (
     Extensions,
     HelpFormatter,
 )
-from celery.tests.case import AppCase, Case, override_stdouts
+from celery.tests.case import (
+    AppCase, override_stdouts, depends_on_current_app,
+)
 
 
 class Object(object):
@@ -36,7 +38,7 @@ class MockCommand(Command):
         return args, kwargs
 
 
-class test_Extensions(Case):
+class test_Extensions(AppCase):
 
     def test_load(self):
         with patch('pkg_resources.iter_entry_points') as iterep:
@@ -65,7 +67,7 @@ class test_Extensions(Case):
                     e.load()
 
 
-class test_HelpFormatter(Case):
+class test_HelpFormatter(AppCase):
 
     def test_format_epilog(self):
         f = HelpFormatter()
@@ -276,21 +278,19 @@ class test_Command(AppCase):
         cmd.show_body = False
         cmd.say_chat('->', 'foo', 'body')
 
+    @depends_on_current_app
     def test_with_cmdline_config(self):
         cmd = MockCommand()
-        try:
-            cmd.enable_config_from_cmdline = True
-            cmd.namespace = 'celeryd'
-            rest = cmd.setup_app_from_commandline(argv=[
-                '--loglevel=INFO', '--',
-                'broker.url=amqp://broker.example.com',
-                '.prefetch_multiplier=100'])
-            self.assertEqual(cmd.app.conf.BROKER_URL,
-                             'amqp://broker.example.com')
-            self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
-            self.assertListEqual(rest, ['--loglevel=INFO'])
-        finally:
-            cmd.app.conf.BROKER_URL = 'memory://'
+        cmd.enable_config_from_cmdline = True
+        cmd.namespace = 'celeryd'
+        rest = cmd.setup_app_from_commandline(argv=[
+            '--loglevel=INFO', '--',
+            'broker.url=amqp://broker.example.com',
+            '.prefetch_multiplier=100'])
+        self.assertEqual(cmd.app.conf.BROKER_URL,
+                         'amqp://broker.example.com')
+        self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
+        self.assertListEqual(rest, ['--loglevel=INFO'])
 
     def test_find_app(self):
         cmd = MockCommand()

+ 5 - 7
celery/tests/bin/test_beat.py

@@ -70,13 +70,11 @@ class test_Beat(AppCase):
         self.assertEqual(b2.loglevel, logging.DEBUG)
 
     def test_colorize(self):
-        from celery import Celery
-        app = Celery(set_as_current=False)
-        app.log.setup = Mock()
-        b = beatapp.Beat(app=app, no_color=True)
+        self.app.log.setup = Mock()
+        b = beatapp.Beat(app=self.app, no_color=True)
         b.setup_logging()
-        self.assertTrue(app.log.setup.called)
-        self.assertEqual(app.log.setup.call_args[1]['colorize'], False)
+        self.assertTrue(self.app.log.setup.called)
+        self.assertEqual(self.app.log.setup.call_args[1]['colorize'], False)
 
     def test_init_loader(self):
         b = beatapp.Beat(app=self.app)
@@ -179,7 +177,7 @@ class test_div(AppCase):
     def test_main(self):
         sys.argv = [sys.argv[0], '-s', 'foo']
         try:
-            beat_bin.main()
+            beat_bin.main(app=self.app)
             self.assertTrue(MockBeat.running)
         finally:
             MockBeat.running = False

+ 30 - 25
celery/tests/bin/test_celery.py

@@ -7,7 +7,6 @@ from datetime import datetime
 from mock import Mock, patch
 
 from celery import __main__
-from celery import task
 from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 from celery.bin.base import Error
 from celery.bin.celery import (
@@ -30,15 +29,10 @@ from celery.bin.celery import (
     command,
 )
 
-from celery.tests.case import AppCase, Case, WhateverIO, override_stdouts
+from celery.tests.case import AppCase, WhateverIO, override_stdouts
 
 
-@task()
-def add(x, y):
-    return x + y
-
-
-class test__main__(Case):
+class test__main__(AppCase):
 
     def test_warn_deprecated(self):
         with override_stdouts() as (stdout, _):
@@ -167,28 +161,35 @@ class test_list(AppCase):
 
 class test_call(AppCase):
 
+    def setup(self):
+
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+        self.add = add
+
     @patch('celery.app.base.Celery.send_task')
     def test_run(self, send_task):
         a = call(app=self.app, stderr=WhateverIO(), stdout=WhateverIO())
-        a.run('tasks.add')
+        a.run(self.add.name)
         self.assertTrue(send_task.called)
 
-        a.run('tasks.add',
+        a.run(self.add.name,
               args=dumps([4, 4]),
               kwargs=dumps({'x': 2, 'y': 2}))
         self.assertEqual(send_task.call_args[1]['args'], [4, 4])
         self.assertEqual(send_task.call_args[1]['kwargs'], {'x': 2, 'y': 2})
 
-        a.run('tasks.add', expires=10, countdown=10)
+        a.run(self.add.name, expires=10, countdown=10)
         self.assertEqual(send_task.call_args[1]['expires'], 10)
         self.assertEqual(send_task.call_args[1]['countdown'], 10)
 
         now = datetime.now()
         iso = now.isoformat()
-        a.run('tasks.add', expires=iso)
+        a.run(self.add.name, expires=iso)
         self.assertEqual(send_task.call_args[1]['expires'], now)
         with self.assertRaises(ValueError):
-            a.run('tasks.add', expires='foobaribazibar')
+            a.run(self.add.name, expires='foobaribazibar')
 
 
 class test_purge(AppCase):
@@ -208,6 +209,13 @@ class test_purge(AppCase):
 
 class test_result(AppCase):
 
+    def setup(self):
+
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+        self.add = add
+
     def test_run(self):
         with patch('celery.result.AsyncResult.get') as get:
             out = WhateverIO()
@@ -217,11 +225,11 @@ class test_result(AppCase):
             self.assertIn('Jerry', out.getvalue())
 
             get.return_value = 'Elaine'
-            r.run('id', task=add.name)
+            r.run('id', task=self.add.name)
             self.assertIn('Elaine', out.getvalue())
 
             with patch('celery.result.AsyncResult.traceback') as tb:
-                r.run('id', task=add.name, traceback=True)
+                r.run('id', task=self.add.name, traceback=True)
                 self.assertIn(str(tb), out.getvalue())
 
 
@@ -417,15 +425,12 @@ class test_inspect(AppCase):
         self.assertTrue(inspect(app=self.app).epilog)
 
     def test_do_call_method_sql_transport_type(self):
-        prev, self.app.connection = self.app.connection, Mock()
-        try:
-            conn = self.app.connection.return_value = Mock(name='Connection')
-            conn.transport.driver_type = 'sql'
-            i = inspect(app=self.app)
-            with self.assertRaises(i.Error):
-                i.do_call_method(['ping'])
-        finally:
-            self.app.connection = prev
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock(name='Connection')
+        conn.transport.driver_type = 'sql'
+        i = inspect(app=self.app)
+        with self.assertRaises(i.Error):
+            i.do_call_method(['ping'])
 
     def test_say_directions(self):
         i = inspect(self.app)
@@ -561,7 +566,7 @@ class test_main(AppCase):
         cmd.execute_from_commandline.assert_called_with(None)
 
 
-class test_compat(Case):
+class test_compat(AppCase):
 
     def test_compat_command_decorator(self):
         with patch('celery.bin.celery.CeleryCommand') as CC:

+ 10 - 9
celery/tests/bin/test_celeryd_detach.py

@@ -9,11 +9,11 @@ from celery.bin.celeryd_detach import (
     main,
 )
 
-from celery.tests.case import Case, override_stdouts
+from celery.tests.case import AppCase, override_stdouts
 
 
 if not IS_WINDOWS:
-    class test_detached(Case):
+    class test_detached(AppCase):
 
         @patch('celery.bin.celeryd_detach.detached')
         @patch('os.execv')
@@ -32,17 +32,17 @@ if not IS_WINDOWS:
 
             execv.side_effect = Exception('foo')
             r = detach('/bin/boo', ['a', 'b', 'c'],
-                       logfile='/var/log', pidfile='/var/pid')
+                       logfile='/var/log', pidfile='/var/pid', app=self.app)
             context.__enter__.assert_called_with()
             self.assertTrue(logger.critical.called)
             setup_logs.assert_called_with('ERROR', '/var/log')
             self.assertEqual(r, 1)
 
 
-class test_PartialOptionParser(Case):
+class test_PartialOptionParser(AppCase):
 
     def test_parser(self):
-        x = detached_celeryd()
+        x = detached_celeryd(self.app)
         p = x.Parser('celeryd_detach')
         options, values = p.parse_args(['--logfile=foo', '--fake', '--enable',
                                         'a', 'b', '-c1', '-d', '2'])
@@ -64,13 +64,13 @@ class test_PartialOptionParser(Case):
         p.get_option('--logfile').nargs = 1
 
 
-class test_Command(Case):
+class test_Command(AppCase):
     argv = ['--autoscale=10,2', '-c', '1',
             '--logfile=/var/log', '-lDEBUG',
             '--', '.disable_rate_limits=1']
 
     def test_parse_options(self):
-        x = detached_celeryd()
+        x = detached_celeryd(app=self.app)
         o, v, l = x.parse_options('cd', self.argv)
         self.assertEqual(o.logfile, '/var/log')
         self.assertEqual(l, ['--autoscale=10,2', '-c', '1',
@@ -81,7 +81,7 @@ class test_Command(Case):
     @patch('sys.exit')
     @patch('celery.bin.celeryd_detach.detach')
     def test_execute_from_commandline(self, detach, exit):
-        x = detached_celeryd()
+        x = detached_celeryd(app=self.app)
         x.execute_from_commandline(self.argv)
         self.assertTrue(exit.called)
         detach.assert_called_with(
@@ -92,10 +92,11 @@ class test_Command(Case):
                 '--logfile=/var/log', '--pidfile=celeryd.pid',
                 '--', '.disable_rate_limits=1'
             ],
+            app=self.app,
         )
 
     @patch('celery.bin.celeryd_detach.detached_celeryd')
     def test_main(self, command):
         c = command.return_value = Mock()
-        main()
+        main(self.app)
         c.execute_from_commandline.assert_called_with()

+ 4 - 4
celery/tests/bin/test_celeryevdump.py

@@ -9,12 +9,12 @@ from celery.events.dumper import (
     evdump,
 )
 
-from celery.tests.case import Case, WhateverIO
+from celery.tests.case import AppCase, WhateverIO
 
 
-class test_Dumper(Case):
+class test_Dumper(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.out = WhateverIO()
         self.dumper = Dumper(out=self.out)
 
@@ -44,7 +44,7 @@ class test_Dumper(Case):
     @patch('celery.events.EventReceiver.capture')
     def test_evdump(self, capture):
         capture.side_effect = KeyboardInterrupt()
-        evdump()
+        evdump(app=self.app)
 
     def test_evdump_error_handler(self):
         app = Mock(name='app')

+ 6 - 6
celery/tests/bin/test_multi.py

@@ -19,10 +19,10 @@ from celery.bin.multi import (
     __doc__ as doc,
 )
 
-from celery.tests.case import Case, WhateverIO
+from celery.tests.case import AppCase, WhateverIO
 
 
-class test_functions(Case):
+class test_functions(AppCase):
 
     def test_findsig(self):
         self.assertEqual(findsig(['a', 'b', 'c', '-1']), 1)
@@ -57,7 +57,7 @@ class test_functions(Case):
         self.assertEqual(quote("the 'quick"), "'the '\\''quick'")
 
 
-class test_NamespacedOptionParser(Case):
+class test_NamespacedOptionParser(AppCase):
 
     def test_parse(self):
         x = NamespacedOptionParser(['-c:1,3', '4'])
@@ -76,7 +76,7 @@ class test_NamespacedOptionParser(Case):
         self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
 
 
-class test_multi_args(Case):
+class test_multi_args(AppCase):
 
     @patch('socket.gethostname')
     def test_parse(self, gethostname):
@@ -160,9 +160,9 @@ class test_multi_args(Case):
         )
 
 
-class test_MultiTool(Case):
+class test_MultiTool(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.fh = WhateverIO()
         self.env = {}
         self.t = MultiTool(env=self.env, fh=self.fh)

+ 53 - 74
celery/tests/bin/test_worker.py

@@ -12,7 +12,6 @@ from nose import SkipTest
 from billiard import current_process
 from kombu import Exchange, Queue
 
-from celery import Celery
 from celery import platforms
 from celery import signals
 from celery.app import trace
@@ -68,33 +67,27 @@ class Worker(cd.Worker):
 class test_Worker(WorkerAppCase):
     Worker = Worker
 
-    def teardown(self):
-        self.app.conf.CELERY_INCLUDE = ()
-
     @disable_stdouts
     def test_queues_string(self):
-        celery = Celery(set_as_current=False)
-        w = celery.Worker()
+        w = self.app.Worker()
         w.setup_queues('foo,bar,baz')
         self.assertEqual(w.queues, ['foo', 'bar', 'baz'])
-        self.assertTrue('foo' in celery.amqp.queues)
+        self.assertTrue('foo' in self.app.amqp.queues)
 
     @disable_stdouts
     def test_cpu_count(self):
-        celery = Celery(set_as_current=False)
         with patch('celery.worker.cpu_count') as cpu_count:
             cpu_count.side_effect = NotImplementedError()
-            w = celery.Worker(concurrency=None)
+            w = self.app.Worker(concurrency=None)
             self.assertEqual(w.concurrency, 2)
-        w = celery.Worker(concurrency=5)
+        w = self.app.Worker(concurrency=5)
         self.assertEqual(w.concurrency, 5)
 
     @disable_stdouts
     def test_windows_B_option(self):
-        celery = Celery(set_as_current=False)
-        celery.IS_WINDOWS = True
+        self.app.IS_WINDOWS = True
         with self.assertRaises(SystemExit):
-            worker(app=celery).run(beat=True)
+            worker(app=self.app).run(beat=True)
 
     def test_setup_concurrency_very_early(self):
         x = worker()
@@ -124,26 +117,23 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_invalid_loglevel_gives_error(self):
-        x = worker(app=Celery(set_as_current=False))
+        x = worker(app=self.app)
         with self.assertRaises(SystemExit):
             x.run(loglevel='GRIM_REAPER')
 
     def test_no_loglevel(self):
-        app = Celery(set_as_current=False)
-        app.Worker = Mock()
-        worker(app=app).run(loglevel=None)
+        self.app.Worker = Mock()
+        worker(app=self.app).run(loglevel=None)
 
     def test_tasklist(self):
-        celery = Celery(set_as_current=False)
-        worker = celery.Worker()
+        worker = self.app.Worker()
         self.assertTrue(worker.app.tasks)
         self.assertTrue(worker.app.finalized)
         self.assertTrue(worker.tasklist(include_builtins=True))
         worker.tasklist(include_builtins=False)
 
     def test_extra_info(self):
-        celery = Celery(set_as_current=False)
-        worker = celery.Worker()
+        worker = self.app.Worker()
         worker.loglevel = logging.WARNING
         self.assertFalse(worker.extra_info())
         worker.loglevel = logging.INFO
@@ -154,6 +144,7 @@ class test_Worker(WorkerAppCase):
         worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
 
+    @disable_stdouts
     def test_run_worker(self):
         handlers = {}
 
@@ -193,29 +184,21 @@ class test_Worker(WorkerAppCase):
         worker.autoscale = 13, 10
         self.assertTrue(worker.startup_info())
 
-        app = Celery(set_as_current=False)
-        worker = self.Worker(app=app, queues='foo,bar,baz,xuzzy,do,re,mi')
-        prev, app.loader = app.loader, Mock()
-        try:
-            app.loader.__module__ = 'acme.baked_beans'
-            self.assertTrue(worker.startup_info())
-        finally:
-            app.loader = prev
+        prev_loader = self.app.loader
+        worker = self.Worker(app=self.app, queues='foo,bar,baz,xuzzy,do,re,mi')
+        self.app.loader = Mock()
+        self.app.loader.__module__ = 'acme.baked_beans'
+        self.assertTrue(worker.startup_info())
 
-        prev, app.loader = app.loader, Mock()
-        try:
-            app.loader.__module__ = 'celery.loaders.foo'
-            self.assertTrue(worker.startup_info())
-        finally:
-            app.loader = prev
+        self.app.loader = Mock()
+        self.app.loader.__module__ = 'celery.loaders.foo'
+        self.assertTrue(worker.startup_info())
 
         from celery.loaders.app import AppLoader
-        prev, app.loader = app.loader, AppLoader(app=self.app)
-        try:
-            self.assertTrue(worker.startup_info())
-        finally:
-            app.loader = prev
+        self.app.loader = AppLoader(app=self.app)
+        self.assertTrue(worker.startup_info())
 
+        self.app.loader = prev_loader
         worker.send_events = True
         self.assertTrue(worker.startup_info())
 
@@ -239,32 +222,32 @@ class test_Worker(WorkerAppCase):
     def test_init_queues(self):
         app = self.app
         c = app.conf
-        p, app.amqp.queues = app.amqp.queues, app.amqp.Queues({
+        app.amqp.queues = app.amqp.Queues({
             'celery': {'exchange': 'celery',
                        'routing_key': 'celery'},
             'video': {'exchange': 'video',
-                      'routing_key': 'video'}})
-        try:
-            worker = self.Worker(app=self.app)
-            worker.setup_queues(['video'])
-            self.assertIn('video', app.amqp.queues)
-            self.assertIn('video', app.amqp.queues.consume_from)
-            self.assertIn('celery', app.amqp.queues)
-            self.assertNotIn('celery', app.amqp.queues.consume_from)
-
-            c.CELERY_CREATE_MISSING_QUEUES = False
-            del(app.amqp.queues)
-            with self.assertRaises(ImproperlyConfigured):
-                self.Worker(app=self.app).setup_queues(['image'])
-            del(app.amqp.queues)
-            c.CELERY_CREATE_MISSING_QUEUES = True
-            worker = self.Worker(app=self.app)
-            worker.setup_queues(queues=['image'])
-            self.assertIn('image', app.amqp.queues.consume_from)
-            self.assertEqual(Queue('image', Exchange('image'),
-                             routing_key='image'), app.amqp.queues['image'])
-        finally:
-            app.amqp.queues = p
+                      'routing_key': 'video'},
+        })
+        worker = self.Worker(app=self.app)
+        worker.setup_queues(['video'])
+        self.assertIn('video', app.amqp.queues)
+        self.assertIn('video', app.amqp.queues.consume_from)
+        self.assertIn('celery', app.amqp.queues)
+        self.assertNotIn('celery', app.amqp.queues.consume_from)
+
+        c.CELERY_CREATE_MISSING_QUEUES = False
+        del(app.amqp.queues)
+        with self.assertRaises(ImproperlyConfigured):
+            self.Worker(app=self.app).setup_queues(['image'])
+        del(app.amqp.queues)
+        c.CELERY_CREATE_MISSING_QUEUES = True
+        worker = self.Worker(app=self.app)
+        worker.setup_queues(queues=['image'])
+        self.assertIn('image', app.amqp.queues.consume_from)
+        self.assertEqual(
+            Queue('image', Exchange('image'), routing_key='image'),
+            app.amqp.queues['image'],
+        )
 
     @disable_stdouts
     def test_autoscale_argument(self):
@@ -272,6 +255,7 @@ class test_Worker(WorkerAppCase):
         self.assertListEqual(worker1.autoscale, [10, 3])
         worker2 = self.Worker(app=self.app, autoscale='10')
         self.assertListEqual(worker2.autoscale, [10, 0])
+        self.assert_no_logging_side_effect()
 
     def test_include_argument(self):
         worker1 = self.Worker(app=self.app, include='some.module')
@@ -318,16 +302,11 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_on_start_custom_logging(self):
-        prev, self.app.log.redirect_stdouts = (
-            self.app.log.redirect_stdouts, Mock(),
-        )
-        try:
-            worker = self.Worker(app=self.app, redirect_stoutds=True)
-            worker._custom_logging = True
-            worker.on_start()
-            self.assertFalse(self.app.log.redirect_stdouts.called)
-        finally:
-            self.app.log.redirect_stdouts = prev
+        self.app.log.redirect_stdouts = Mock()
+        worker = self.Worker(app=self.app, redirect_stoutds=True)
+        worker._custom_logging = True
+        worker.on_start()
+        self.assertFalse(self.app.log.redirect_stdouts.called)
 
     def test_setup_logging_no_color(self):
         worker = self.Worker(
@@ -463,7 +442,7 @@ class test_funs(WorkerAppCase):
         p, cd.Worker = cd.Worker, Worker
         s, sys.argv = sys.argv, ['worker', '--discard']
         try:
-            worker_main()
+            worker_main(app=self.app)
         finally:
             cd.Worker = p
             sys.argv = s

+ 128 - 47
celery/tests/case.py

@@ -9,6 +9,7 @@ except AttributeError:
     from unittest2.util import safe_repr, unorderable_list_difference  # noqa
 
 import importlib
+import inspect
 import logging
 import os
 import platform
@@ -18,6 +19,7 @@ import time
 import warnings
 
 from contextlib import contextmanager
+from copy import deepcopy
 from datetime import datetime, timedelta
 from functools import partial, wraps
 from types import ModuleType
@@ -27,28 +29,93 @@ try:
 except ImportError:
     import mock  # noqa
 from nose import SkipTest
+from kombu import Queue
 from kombu.log import NullHandler
-from kombu.utils import nested
+from kombu.utils import nested, symbol_by_name
 
+from celery import Celery
+from celery.app import current_app
+from celery.backends.cache import CacheBackend, DummyClient
 from celery.five import (
     WhateverIO, builtins, items, reraise,
     string_t, values, open_fqdn,
 )
 from celery.utils.functional import noop
+from celery.utils.imports import qualname
 
 __all__ = [
     'Case', 'AppCase', 'Mock', 'patch', 'call', 'skip_unless_module',
-    'wrap_logger', 'eager_tasks', 'with_environ', 'sleepdeprived',
+    'wrap_logger', 'with_environ', 'sleepdeprived',
     'skip_if_environ', 'skip_if_quick', 'todo', 'skip', 'skip_if',
     'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
     'replace_module_value', 'sys_platform', 'reset_modules',
     'patch_modules', 'mock_context', 'mock_open', 'patch_many',
-    'patch_settings', 'assert_signal_called', 'skip_if_pypy',
+    'assert_signal_called', 'skip_if_pypy',
     'skip_if_jython', 'body_from_sig', 'restore_logging',
 ]
 patch = mock.patch
 call = mock.call
 
+CASE_REDEFINES_SETUP = """\
+{name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
+"""
+CASE_REDEFINES_TEARDOWN = """\
+{name} (subclass of AppCase) redefines private "tearDown", \
+should be: "teardown"\
+"""
+
+CELERY_TEST_CONFIG = {
+    #: Don't want log output when running suite.
+    'CELERYD_HIJACK_ROOT_LOGGER': False,
+    'CELERY_SEND_TASK_ERROR_EMAILS': False,
+    'CELERY_DEFAULT_QUEUE': 'testcelery',
+    'CELERY_DEFAULT_EXCHANGE': 'testcelery',
+    'CELERY_DEFAULT_ROUTING_KEY': 'testcelery',
+    'CELERY_QUEUES': (
+        Queue('testcelery', routing_key='testcelery'),
+    ),
+    'CELERY_ENABLE_UTC': True,
+    'CELERY_TIMEZONE': 'UTC',
+    'CELERYD_LOG_COLOR': False,
+
+    # Mongo results tests (only executed if installed and running)
+    'CELERY_MONGODB_BACKEND_SETTINGS': {
+        'host': os.environ.get('MONGO_HOST') or 'localhost',
+        'port': os.environ.get('MONGO_PORT') or 27017,
+        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
+        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION')
+                                or 'taskmeta_collection'),
+        'user': os.environ.get('MONGO_USER'),
+        'password': os.environ.get('MONGO_PASSWORD'),
+    }
+}
+
+
+class Trap(object):
+
+    def __getattr__(self, name):
+        raise RuntimeError('Test depends on current_app')
+
+
+class UnitLogging(symbol_by_name(Celery.log_cls)):
+
+    def __init__(self, *args, **kwargs):
+        super(UnitLogging, self).__init__(*args, **kwargs)
+        self.already_setup = True
+
+
+def UnitApp(name=None, broker=None, backend=None,
+            set_as_current=False, log=UnitLogging, **kwargs):
+
+    app = Celery(name or 'celery.tests',
+                 broker=broker or 'memory://',
+                 backend=backend or 'cache+memory://',
+                 set_as_current=set_as_current,
+                 log=log,
+                 **kwargs)
+    app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
+    return app
+
 
 class Mock(mock.Mock):
 
@@ -204,29 +271,74 @@ class Case(unittest.TestCase):
             self.fail(self._formatMessage(msg, standardMsg))
 
 
+def depends_on_current_app(fun):
+    if inspect.isclass(fun):
+        fun.contained = False
+    else:
+        @wraps(fun)
+        def __inner(self, *args, **kwargs):
+            self.app.set_current()
+            return fun(self, *args, **kwargs)
+        return __inner
+
+
 class AppCase(Case):
     contained = True
 
+    def __new__(cls, *args, **kwargs):
+        if cls.__dict__.get('setUp'):
+            raise RuntimeError(CASE_REDEFINES_SETUP.format(name=qualname(cls)))
+        if cls.__dict__.get('tearDown'):
+            raise RuntimeError(CASE_REDEFINES_TEARDOWN.format(
+                name=qualname(cls)),
+            )
+        return super(AppCase, cls).__new__(cls, *args, **kwargs)
+
+    def Celery(self, *args, **kwargs):
+        return UnitApp(*args, **kwargs)
+
     def setUp(self):
-        from celery import Celery
-        from celery.app import current_app
-        from celery.backends.cache import CacheBackend, DummyClient
+        from celery import _state
         self._current_app = current_app()
-        app = self.app = (Celery(set_as_current=False)
-                          if self.contained else self._current_app)
-        if isinstance(app.backend, CacheBackend):
-            if isinstance(app.backend.client, DummyClient):
-                app.backend.client.cache.clear()
-        app.backend._cache.clear()
+        self._default_app = _state.default_app
+        trap = Trap()
+        _state.set_default_app(trap)
+        _state._tls.current_app = trap
+
+        self.app = self.Celery(set_as_current=False)
+        if not self.contained:
+            self.app.set_current()
         root = logging.getLogger()
         self.__rootlevel = root.level
         self.__roothandlers = root.handlers
-        self.setup()
+        try:
+            self.setup()
+        except:
+            self._teardown_app()
+            raise
+
+    def _teardown_app(self):
+        backend = self.app.__dict__.get('backend')
+        if backend is not None:
+            if isinstance(backend, CacheBackend):
+                if isinstance(backend.client, DummyClient):
+                    backend.client.cache.clear()
+                backend._cache.clear()
+        from celery._state import _tls, set_default_app
+        set_default_app(self._default_app)
+        _tls.current_app = self._current_app
+        if self.app is not self._current_app:
+            self.app.close()
+        self.app = None
 
     def tearDown(self):
-        self.teardown()
-        self._current_app.set_current()
+        try:
+            self.teardown()
+        finally:
+            self._teardown_app()
+        self.assert_no_logging_side_effect()
 
+    def assert_no_logging_side_effect(self):
         root = logging.getLogger()
         this = '.'.join([self.__class__.__name__, self._testMethodName])
         if root.level != self.__rootlevel:
@@ -258,16 +370,6 @@ def wrap_logger(logger, loglevel=logging.ERROR):
         logger.handlers = old_handlers
 
 
-@contextmanager
-def eager_tasks(app):
-    prev = app.conf.CELERY_ALWAYS_EAGER
-    app.conf.CELERY_ALWAYS_EAGER = True
-    try:
-        yield True
-    finally:
-        app.conf.CELERY_ALWAYS_EAGER = prev
-
-
 def with_environ(env_name, env_value):
 
     def _envpatched(fun):
@@ -279,8 +381,7 @@ def with_environ(env_name, env_value):
             try:
                 return fun(*args, **kwargs)
             finally:
-                if prev_val is not None:
-                    os.environ[env_name] = prev_val
+                os.environ[env_name] = prev_val or ''
 
         return _patch_environ
     return _envpatched
@@ -554,26 +655,6 @@ def patch_many(*targets):
     return nested(*[patch(target) for target in targets])
 
 
-@contextmanager
-def patch_settings(app, **config):
-    if app is None:
-        from celery import current_app
-        app = current_app
-    prev = {}
-    for key, value in items(config):
-        try:
-            prev[key] = getattr(app.conf, key)
-        except AttributeError:
-            pass
-        setattr(app.conf, key, value)
-
-    try:
-        yield app.conf
-    finally:
-        for key, value in items(prev):
-            setattr(app.conf, key, value)
-
-
 @contextmanager
 def assert_signal_called(signal, **expected):
     handler = Mock()

+ 63 - 10
celery/tests/compat_modules/test_compat.py

@@ -1,6 +1,15 @@
 from __future__ import absolute_import
 
-from celery.tests.case import AppCase
+from datetime import timedelta
+
+from celery.schedules import schedule
+from celery.task import (
+    periodic_task,
+    PeriodicTask
+)
+from celery.utils.timeutils import timedelta_seconds
+
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 class test_Task(AppCase):
@@ -11,16 +20,60 @@ class test_Task(AppCase):
         class timkX(OldTask):
             abstract = True
 
-        app = Celery(set_as_current=False, accept_magic_kwargs=True)
-        timkX.bind(app)
-        # see #918
-        self.assertFalse(timkX.accept_magic_kwargs)
+        with self.Celery(set_as_current=False,
+                         accept_magic_kwargs=True) as app:
+            timkX.bind(app)
+            # see #918
+            self.assertFalse(timkX.accept_magic_kwargs)
 
-        from celery import Task as NewTask
+            from celery import Task as NewTask
 
-        class timkY(NewTask):
-            abstract = True
+            class timkY(NewTask):
+                abstract = True
+
+            timkY.bind(app)
+            self.assertFalse(timkY.accept_magic_kwargs)
+
+
+@depends_on_current_app
+class test_periodic_tasks(AppCase):
+
+    def setup(self):
+        @periodic_task(app=self.app, shared=False,
+                       run_every=schedule(timedelta(hours=1), app=self.app))
+        def my_periodic():
+            pass
+        self.my_periodic = my_periodic
+
+    def now(self):
+        return self.app.now()
+
+    def test_must_have_run_every(self):
+        with self.assertRaises(NotImplementedError):
+            type('Foo', (PeriodicTask, ), {'__module__': __name__})
+
+    def test_remaining_estimate(self):
+        s = self.my_periodic.run_every
+        self.assertIsInstance(
+            s.remaining_estimate(s.maybe_make_aware(self.now())),
+            timedelta)
+
+    def test_is_due_not_due(self):
+        due, remaining = self.my_periodic.run_every.is_due(self.now())
+        self.assertFalse(due)
+        # This assertion may fail if executed in the
+        # first minute of an hour, thus 59 instead of 60
+        self.assertGreater(remaining, 59)
 
-        timkY.bind(app)
-        self.assertFalse(timkY.accept_magic_kwargs)
+    def test_is_due(self):
+        p = self.my_periodic
+        due, remaining = p.run_every.is_due(
+            self.now() - p.run_every.run_every,
+        )
+        self.assertTrue(due)
+        self.assertEqual(remaining,
+                         timedelta_seconds(p.run_every.run_every))
 
+    def test_schedule_repr(self):
+        p = self.my_periodic
+        self.assertTrue(repr(p.run_every))

+ 6 - 7
celery/tests/compat_modules/test_compat_utils.py

@@ -1,14 +1,15 @@
 from __future__ import absolute_import
 
-
 import celery
+
 from celery.app.task import Task as ModernTask
 from celery.task.base import Task as CompatTask
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase, depends_on_current_app
 
 
-class test_MagicModule(Case):
+@depends_on_current_app
+class test_MagicModule(AppCase):
 
     def test_class_property_set_without_type(self):
         self.assertTrue(ModernTask.__dict__['app'].__get__(CompatTask()))
@@ -21,10 +22,8 @@ class test_MagicModule(Case):
 
         class X(CompatTask):
             pass
-
-        app = celery.Celery(set_as_current=False)
-        ModernTask.__dict__['app'].__set__(X(), app)
-        self.assertEqual(X.app, app)
+        ModernTask.__dict__['app'].__set__(X(), self.app)
+        self.assertIs(X.app, self.app)
 
     def test_dir(self):
         self.assertTrue(dir(celery.messaging))

+ 4 - 3
celery/tests/compat_modules/test_decorators.py

@@ -4,21 +4,22 @@ import warnings
 
 from celery.task import base
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 def add(x, y):
     return x + y
 
 
-class test_decorators(Case):
+@depends_on_current_app
+class test_decorators(AppCase):
 
     def test_task_alias(self):
         from celery import task
         self.assertTrue(task.__file__)
         self.assertTrue(task(add))
 
-    def setUp(self):
+    def setup(self):
         with warnings.catch_warnings(record=True):
             from celery import decorators
             self.decorators = decorators

+ 13 - 10
celery/tests/compat_modules/test_http.py

@@ -13,7 +13,7 @@ from kombu.utils.encoding import from_utf8
 
 from celery.five import StringIO, items
 from celery.task import http
-from celery.tests.case import AppCase, Case, eager_tasks
+from celery.tests.case import AppCase, Case
 
 
 @contextmanager
@@ -140,16 +140,19 @@ class test_HttpDispatch(AppCase):
 
 
 class test_URL(AppCase):
-    contained = False
 
     def test_URL_get_async(self):
-        with eager_tasks(self.app):
-            with mock_urlopen(success_response(100)):
-                d = http.URL('http://example.com/mul').get_async(x=10, y=10)
-                self.assertEqual(d.get(), 100)
+        self.app.conf.CELERY_ALWAYS_EAGER = True
+        with mock_urlopen(success_response(100)):
+            d = http.URL(
+                'http://example.com/mul', app=self.app,
+            ).get_async(x=10, y=10)
+            self.assertEqual(d.get(), 100)
 
     def test_URL_post_async(self):
-        with eager_tasks(self.app):
-            with mock_urlopen(success_response(100)):
-                d = http.URL('http://example.com/mul').post_async(x=10, y=10)
-                self.assertEqual(d.get(), 100)
+        self.app.conf.CELERY_ALWAYS_EAGER = True
+        with mock_urlopen(success_response(100)):
+            d = http.URL(
+                'http://example.com/mul', app=self.app,
+            ).post_async(x=10, y=10)
+            self.assertEqual(d.get(), 100)

+ 3 - 2
celery/tests/compat_modules/test_messaging.py

@@ -1,10 +1,11 @@
 from __future__ import absolute_import
 
 from celery import messaging
-from celery.tests.case import Case
+from celery.tests.case import AppCase, depends_on_current_app
 
 
-class test_compat_messaging_module(Case):
+@depends_on_current_app
+class test_compat_messaging_module(AppCase):
 
     def test_get_consume_set(self):
         conn = messaging.establish_connection()

+ 88 - 42
celery/tests/compat_modules/test_sets.py

@@ -1,44 +1,92 @@
 from __future__ import absolute_import
 
 import anyjson
+import warnings
 
 from mock import Mock, patch
 
+from celery import uuid
+from celery.result import TaskSetResult
 from celery.task import Task
-from celery.task.sets import subtask, TaskSet
 from celery.canvas import Signature
 
+from celery.tests.tasks.test_result import make_mock_group
 from celery.tests.case import AppCase
 
 
-class MockTask(Task):
-    name = 'tasks.add'
+class SetsCase(AppCase):
 
-    def run(self, x, y, **kwargs):
-        return x + y
+    def setup(self):
+        with warnings.catch_warnings(record=True):
+            from celery.task import sets
+            self.sets = sets
+            self.subtask = sets.subtask
+            self.TaskSet = sets.TaskSet
 
-    @classmethod
-    def apply_async(cls, args, kwargs, **options):
-        return (args, kwargs, options)
+        class MockTask(Task):
+            app = self.app
+            name = 'tasks.add'
 
-    @classmethod
-    def apply(cls, args, kwargs, **options):
-        return (args, kwargs, options)
+            def run(self, x, y, **kwargs):
+                return x + y
 
+            @classmethod
+            def apply_async(cls, args, kwargs, **options):
+                return (args, kwargs, options)
 
-class test_subtask(AppCase):
+            @classmethod
+            def apply(cls, args, kwargs, **options):
+                return (args, kwargs, options)
+        self.MockTask = MockTask
+
+
+class test_TaskSetResult(AppCase):
+
+    def setup(self):
+        self.size = 10
+        self.ts = TaskSetResult(uuid(), make_mock_group(self.app, self.size))
+
+    def test_total(self):
+        self.assertEqual(self.ts.total, self.size)
+
+    def test_compat_properties(self):
+        self.assertEqual(self.ts.taskset_id, self.ts.id)
+        self.ts.taskset_id = 'foo'
+        self.assertEqual(self.ts.taskset_id, 'foo')
+
+    def test_compat_subtasks_kwarg(self):
+        x = TaskSetResult(uuid(), subtasks=[1, 2, 3])
+        self.assertEqual(x.results, [1, 2, 3])
+
+    def test_itersubtasks(self):
+        it = self.ts.itersubtasks()
+
+        for i, t in enumerate(it):
+            self.assertEqual(t.get(), i)
+
+
+class test_App(AppCase):
+
+    def test_TaskSet(self):
+        with warnings.catch_warnings(record=True):
+            ts = self.app.TaskSet()
+            self.assertListEqual(ts.tasks, [])
+            self.assertIs(ts.app, self.app)
+
+
+class test_subtask(SetsCase):
 
     def test_behaves_like_type(self):
-        s = subtask('tasks.add', (2, 2), {'cache': True},
-                    {'routing_key': 'CPU-bound'})
-        self.assertDictEqual(subtask(s), s)
+        s = self.subtask('tasks.add', (2, 2), {'cache': True},
+                         {'routing_key': 'CPU-bound'})
+        self.assertDictEqual(self.subtask(s), s)
 
     def test_task_argument_can_be_task_cls(self):
-        s = subtask(MockTask, (2, 2))
-        self.assertEqual(s.task, MockTask.name)
+        s = self.subtask(self.MockTask, (2, 2))
+        self.assertEqual(s.task, self.MockTask.name)
 
     def test_apply_async(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, 2), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         args, kwargs, options = s.apply_async()
@@ -47,7 +95,7 @@ class test_subtask(AppCase):
         self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
 
     def test_delay_argmerge(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         args, kwargs, options = s.delay(10, cache=False, other='foo')
@@ -56,7 +104,7 @@ class test_subtask(AppCase):
         self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
 
     def test_apply_async_argmerge(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         args, kwargs, options = s.apply_async((10, ),
@@ -70,7 +118,7 @@ class test_subtask(AppCase):
                                        'exchange': 'fast'})
 
     def test_apply_argmerge(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         args, kwargs, options = s.apply((10, ),
@@ -85,50 +133,48 @@ class test_subtask(AppCase):
         )
 
     def test_is_JSON_serializable(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         s.args = list(s.args)                   # tuples are not preserved
                                                 # but this doesn't matter.
-        self.assertEqual(s, subtask(anyjson.loads(anyjson.dumps(s))))
+        self.assertEqual(s, self.subtask(anyjson.loads(anyjson.dumps(s))))
 
     def test_repr(self):
-        s = MockTask.subtask((2, ), {'cache': True})
+        s = self.MockTask.subtask((2, ), {'cache': True})
         self.assertIn('2', repr(s))
         self.assertIn('cache=True', repr(s))
 
     def test_reduce(self):
-        s = MockTask.subtask((2, ), {'cache': True})
+        s = self.MockTask.subtask((2, ), {'cache': True})
         cls, args = s.__reduce__()
         self.assertDictEqual(dict(cls(*args)), dict(s))
 
 
-class test_TaskSet(AppCase):
+class test_TaskSet(SetsCase):
 
     def test_task_arg_can_be_iterable__compat(self):
-        ts = TaskSet([MockTask.subtask((i, i))
-                      for i in (2, 4, 8)], app=self.app)
+        ts = self.TaskSet([self.MockTask.subtask((i, i))
+                           for i in (2, 4, 8)], app=self.app)
         self.assertEqual(len(ts), 3)
 
     def test_respects_ALWAYS_EAGER(self):
         app = self.app
 
-        class MockTaskSet(TaskSet):
+        class MockTaskSet(self.TaskSet):
             applied = 0
 
             def apply(self, *args, **kwargs):
                 self.applied += 1
 
         ts = MockTaskSet(
-            [MockTask.subtask((i, i)) for i in (2, 4, 8)],
+            [self.MockTask.subtask((i, i)) for i in (2, 4, 8)],
             app=self.app,
         )
         app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            ts.apply_async()
-        finally:
-            app.conf.CELERY_ALWAYS_EAGER = False
+        ts.apply_async()
         self.assertEqual(ts.applied, 1)
+        app.conf.CELERY_ALWAYS_EAGER = False
 
         with patch('celery.task.sets.get_current_worker_task') as gwt:
             parent = gwt.return_value = Mock()
@@ -143,8 +189,8 @@ class test_TaskSet(AppCase):
             def apply_async(self, *args, **kwargs):
                 applied[0] += 1
 
-        ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)], app=self.app)
+        ts = self.TaskSet([mocksubtask(self.MockTask, (i, i))
+                           for i in (2, 4, 8)], app=self.app)
         ts.apply_async()
         self.assertEqual(applied[0], 3)
 
@@ -157,7 +203,7 @@ class test_TaskSet(AppCase):
 
         # setting current_task
 
-        @self.app.task
+        @self.app.task(shared=False)
         def xyz():
             pass
 
@@ -179,22 +225,22 @@ class test_TaskSet(AppCase):
             def apply(self, *args, **kwargs):
                 applied[0] += 1
 
-        ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)], app=self.app)
+        ts = self.TaskSet([mocksubtask(self.MockTask, (i, i))
+                           for i in (2, 4, 8)], app=self.app)
         ts.apply()
         self.assertEqual(applied[0], 3)
 
     def test_set_app(self):
-        ts = TaskSet([], app=self.app)
+        ts = self.TaskSet([], app=self.app)
         ts.app = 42
         self.assertEqual(ts.app, 42)
 
     def test_set_tasks(self):
-        ts = TaskSet([], app=self.app)
+        ts = self.TaskSet([], app=self.app)
         ts.tasks = [1, 2, 3]
         self.assertEqual(ts, [1, 2, 3])
 
     def test_set_Publisher(self):
-        ts = TaskSet([], app=self.app)
+        ts = self.TaskSet([], app=self.app)
         ts.Publisher = 42
         self.assertEqual(ts.Publisher, 42)

+ 2 - 2
celery/tests/concurrency/test_concurrency.py

@@ -6,10 +6,10 @@ from itertools import count
 from mock import Mock
 
 from celery.concurrency.base import apply_target, BasePool
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class test_BasePool(Case):
+class test_BasePool(AppCase):
 
     def test_apply_target(self):
 

+ 4 - 4
celery/tests/concurrency/test_eventlet.py

@@ -14,13 +14,13 @@ from celery.concurrency.eventlet import (
     TaskPool,
 )
 
-from celery.tests.case import Case, mock_module, patch_many, skip_if_pypy
+from celery.tests.case import AppCase, mock_module, patch_many, skip_if_pypy
 
 
-class EventletCase(Case):
+class EventletCase(AppCase):
 
     @skip_if_pypy
-    def setUp(self):
+    def setup(self):
         if is_pypy:
             raise SkipTest('mock_modules not working on PyPy1.9')
         try:
@@ -30,7 +30,7 @@ class EventletCase(Case):
                 'eventlet not installed, skipping related tests.')
 
     @skip_if_pypy
-    def tearDown(self):
+    def teardown(self):
         for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
             try:
                 del(sys.modules[mod])

+ 7 - 7
celery/tests/concurrency/test_gevent.py

@@ -14,7 +14,7 @@ from celery.concurrency.gevent import (
 )
 
 from celery.tests.case import (
-    Case, mock_module, patch_many, skip_if_pypy,
+    AppCase, mock_module, patch_many, skip_if_pypy,
 )
 
 gevent_modules = (
@@ -26,10 +26,10 @@ gevent_modules = (
 )
 
 
-class GeventCase(Case):
+class GeventCase(AppCase):
 
     @skip_if_pypy
-    def setUp(self):
+    def setup(self):
         try:
             self.gevent = __import__('gevent')
         except ImportError:
@@ -58,7 +58,7 @@ class test_gevent_patch(GeventCase):
                 monkey.patch_all = prev_monkey_patch
 
 
-class test_Schedule(Case):
+class test_Schedule(AppCase):
 
     def test_sched(self):
         with mock_module(*gevent_modules):
@@ -88,7 +88,7 @@ class test_Schedule(Case):
                 g.cancel()
 
 
-class test_TasKPool(Case):
+class test_TaskPool(AppCase):
 
     def test_pool(self):
         with mock_module(*gevent_modules):
@@ -115,7 +115,7 @@ class test_TasKPool(Case):
                 self.assertEqual(x.num_processes, 3)
 
 
-class test_Timer(Case):
+class test_Timer(AppCase):
 
     def test_timer(self):
         with mock_module(*gevent_modules):
@@ -127,7 +127,7 @@ class test_Timer(Case):
             x.schedule.clear.assert_called_with()
 
 
-class test_apply_timeout(Case):
+class test_apply_timeout(AppCase):
 
     def test_apply_timeout(self):
 

+ 3 - 3
celery/tests/concurrency/test_pool.py

@@ -7,7 +7,7 @@ from nose import SkipTest
 
 from billiard.einfo import ExceptionInfo
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 def do_something(i):
@@ -25,9 +25,9 @@ def raise_something(i):
         return ExceptionInfo()
 
 
-class test_TaskPool(Case):
+class test_TaskPool(AppCase):
 
-    def setUp(self):
+    def setup(self):
         try:
             __import__('multiprocessing')
         except ImportError:

+ 2 - 2
celery/tests/concurrency/test_solo.py

@@ -4,10 +4,10 @@ import operator
 
 from celery.concurrency import solo
 from celery.utils.functional import noop
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class test_solo_TaskPool(Case):
+class test_solo_TaskPool(AppCase):
 
     def test_on_start(self):
         x = solo.TaskPool()

+ 2 - 2
celery/tests/concurrency/test_threads.py

@@ -4,7 +4,7 @@ from mock import Mock
 
 from celery.concurrency.threads import NullDict, TaskPool, apply_target
 
-from celery.tests.case import Case, mask_modules, mock_module
+from celery.tests.case import AppCase, Case, mask_modules, mock_module
 
 
 class test_NullDict(Case):
@@ -16,7 +16,7 @@ class test_NullDict(Case):
             x['foo']
 
 
-class test_TaskPool(Case):
+class test_TaskPool(AppCase):
 
     def test_without_threadpool(self):
 

+ 0 - 44
celery/tests/config.py

@@ -1,44 +0,0 @@
-from __future__ import absolute_import
-
-import os
-
-from kombu import Queue
-
-BROKER_URL = 'memory://'
-
-#: warn if config module not found
-os.environ['C_WNOCONF'] = 'yes'
-
-#: Don't want log output when running suite.
-CELERYD_HIJACK_ROOT_LOGGER = False
-
-CELERY_RESULT_BACKEND = 'cache'
-CELERY_CACHE_BACKEND = 'memory'
-CELERY_RESULT_DBURI = 'sqlite:///test.db'
-CELERY_SEND_TASK_ERROR_EMAILS = False
-
-CELERY_DEFAULT_QUEUE = 'testcelery'
-CELERY_DEFAULT_EXCHANGE = 'testcelery'
-CELERY_DEFAULT_ROUTING_KEY = 'testcelery'
-CELERY_QUEUES = (
-    Queue('testcelery', routing_key='testcelery'),
-)
-
-CELERY_ENABLE_UTC = True
-CELERY_TIMEZONE = 'UTC'
-
-CELERYD_LOG_COLOR = False
-
-# Mongo results tests (only executed if installed and running)
-CELERY_MONGODB_BACKEND_SETTINGS = {
-    'host': os.environ.get('MONGO_HOST') or 'localhost',
-    'port': os.environ.get('MONGO_PORT') or 27017,
-    'database': os.environ.get('MONGO_DB') or 'celery_unittests',
-    'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION')
-                            or 'taskmeta_collection'),
-}
-if os.environ.get('MONGO_USER'):
-    CELERY_MONGODB_BACKEND_SETTINGS['user'] = os.environ.get('MONGO_USER')
-if os.environ.get('MONGO_PASSWORD'):
-    CELERY_MONGODB_BACKEND_SETTINGS['password'] = \
-        os.environ.get('MONGO_PASSWORD')

+ 24 - 26
celery/tests/contrib/test_abortable.py

@@ -1,51 +1,49 @@
 from __future__ import absolute_import
 
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
-from celery.result import AsyncResult
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class MyAbortableTask(AbortableTask):
+class test_AbortableTask(AppCase):
 
-    def run(self, **kwargs):
-        return True
+    def setup(self):
 
-
-class test_AbortableTask(Case):
+        @self.app.task(base=AbortableTask, shared=False)
+        def abortable():
+            return True
+        self.abortable = abortable
 
     def test_async_result_is_abortable(self):
-        t = MyAbortableTask()
-        result = t.apply_async()
+        result = self.abortable.apply_async()
         tid = result.id
-        self.assertIsInstance(t.AsyncResult(tid), AbortableAsyncResult)
+        self.assertIsInstance(
+            self.abortable.AsyncResult(tid), AbortableAsyncResult,
+        )
 
     def test_is_not_aborted(self):
-        t = MyAbortableTask()
-        t.push_request()
+        self.abortable.push_request()
         try:
-            result = t.apply_async()
+            result = self.abortable.apply_async()
             tid = result.id
-            self.assertFalse(t.is_aborted(task_id=tid))
+            self.assertFalse(self.abortable.is_aborted(task_id=tid))
         finally:
-            t.pop_request()
+            self.abortable.pop_request()
 
     def test_is_aborted_not_abort_result(self):
-        t = MyAbortableTask()
-        t.AsyncResult = AsyncResult
-        t.push_request()
+        self.abortable.AsyncResult = self.app.AsyncResult
+        self.abortable.push_request()
         try:
-            t.request.id = 'foo'
-            self.assertFalse(t.is_aborted())
+            self.abortable.request.id = 'foo'
+            self.assertFalse(self.abortable.is_aborted())
         finally:
-            t.pop_request()
+            self.abortable.pop_request()
 
     def test_abort_yields_aborted(self):
-        t = MyAbortableTask()
-        t.push_request()
+        self.abortable.push_request()
         try:
-            result = t.apply_async()
+            result = self.abortable.apply_async()
             result.abort()
             tid = result.id
-            self.assertTrue(t.is_aborted(task_id=tid))
+            self.assertTrue(self.abortable.is_aborted(task_id=tid))
         finally:
-            t.pop_request()
+            self.abortable.pop_request()

+ 1 - 1
celery/tests/contrib/test_methods.py

@@ -14,7 +14,7 @@ class test_task_method(AppCase):
             def __init__(self):
                 self.state = 0
 
-            @self.app.task(filter=task_method)
+            @self.app.task(shared=False, filter=task_method)
             def add(self, x):
                 self.state += x
 

+ 5 - 5
celery/tests/contrib/test_migrate.py

@@ -26,7 +26,7 @@ from celery.contrib.migrate import (
     move,
 )
 from celery.utils.encoding import bytes_t, ensure_bytes
-from celery.tests.case import AppCase, Case, Mock, override_stdouts
+from celery.tests.case import AppCase, Mock, override_stdouts
 
 # hack to ignore error at shutdown
 QoS.restore_at_shutdown = False
@@ -52,7 +52,7 @@ def Message(body, exchange='exchange', routing_key='rkey',
     )
 
 
-class test_State(Case):
+class test_State(AppCase):
 
     def test_strtotal(self):
         x = State()
@@ -178,7 +178,7 @@ class test_start_filter(AppCase):
             self.assertTrue(stop_filtering_raised)
 
 
-class test_filter_callback(Case):
+class test_filter_callback(AppCase):
 
     def test_filter(self):
         callback = Mock()
@@ -193,7 +193,7 @@ class test_filter_callback(Case):
         callback.assert_called_with(t1, message)
 
 
-class test_utils(Case):
+class test_utils(AppCase):
 
     def test_task_id_in(self):
         self.assertTrue(task_id_in(['A'], {'id': 'A'}, Mock()))
@@ -243,7 +243,7 @@ class test_utils(Case):
             )
 
 
-class test_migrate_task(Case):
+class test_migrate_task(AppCase):
 
     def test_removes_compression_header(self):
         x = Message('foo', compression='zlib')

+ 10 - 13
celery/tests/events/test_events.py

@@ -4,7 +4,6 @@ import socket
 
 from mock import Mock
 
-from celery import Celery
 from celery.events import Event
 from celery.tests.case import AppCase
 
@@ -41,21 +40,19 @@ class test_Event(AppCase):
 class test_EventDispatcher(AppCase):
 
     def test_redis_uses_fanout_exchange(self):
-        with Celery(set_as_current=False) as app:
-            app.connection = Mock()
-            conn = app.connection.return_value = Mock()
-            conn.transport.driver_type = 'redis'
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock()
+        conn.transport.driver_type = 'redis'
 
-            dispatcher = app.events.Dispatcher(conn, enabled=False)
-            self.assertEqual(dispatcher.exchange.type, 'fanout')
+        dispatcher = self.app.events.Dispatcher(conn, enabled=False)
+        self.assertEqual(dispatcher.exchange.type, 'fanout')
 
     def test_others_use_topic_exchange(self):
-        with Celery(set_as_current=False) as app:
-            app.connection = Mock()
-            conn = app.connection.return_value = Mock()
-            conn.transport.driver_type = 'amqp'
-            dispatcher = app.events.Dispatcher(conn, enabled=False)
-            self.assertEqual(dispatcher.exchange.type, 'topic')
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock()
+        conn.transport.driver_type = 'amqp'
+        dispatcher = self.app.events.Dispatcher(conn, enabled=False)
+        self.assertEqual(dispatcher.exchange.type, 'topic')
 
     def test_takes_channel_connection(self):
         x = self.app.events.Dispatcher(channel=Mock())

+ 4 - 4
celery/tests/events/test_state.py

@@ -18,7 +18,7 @@ from celery.events.state import (
 )
 from celery.five import range
 from celery.utils import uuid
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 class replay(object):
@@ -152,7 +152,7 @@ class ev_snapshot(replay):
                                uuid=uuid(), hostname=worker))
 
 
-class test_Worker(Case):
+class test_Worker(AppCase):
 
     def test_equality(self):
         self.assertEqual(Worker(hostname='foo').hostname, 'foo')
@@ -192,7 +192,7 @@ class test_Worker(Case):
         self.assertEqual(len(worker.heartbeats), 1)
 
 
-class test_Task(Case):
+class test_Task(AppCase):
 
     def test_equality(self):
         self.assertEqual(Task(uuid='foo').uuid, 'foo')
@@ -265,7 +265,7 @@ class test_Task(Case):
         self.assertTrue(repr(Task(uuid='xxx', name='tasks.add')))
 
 
-class test_State(Case):
+class test_State(AppCase):
 
     def test_repr(self):
         self.assertTrue(repr(State()))

+ 5 - 7
celery/tests/fixups/test_django.py

@@ -5,7 +5,6 @@ import os
 from contextlib import contextmanager
 from mock import Mock, patch
 
-from celery import Celery
 from celery.fixups.django import (
     _maybe_close_fd,
     fixup,
@@ -62,10 +61,9 @@ class test_DjangoFixup(AppCase):
             self.assertIsNone(DjangoFixup(self.app)._close_old_connections)
 
     def test_install(self):
-        app = Celery(set_as_current=False)
-        app.conf = {'CELERY_DB_REUSE_MAX': None}
-        app.loader = Mock()
-        with self.fixup_context(app) as (f, _, _):
+        self.app.conf = {'CELERY_DB_REUSE_MAX': None}
+        self.app.loader = Mock()
+        with self.fixup_context(self.app) as (f, _, _):
             with patch_many('os.getcwd', 'sys.path',
                             'celery.fixups.django.signals') as (cw, p, sigs):
                 cw.return_value = '/opt/vandelay'
@@ -80,8 +78,8 @@ class test_DjangoFixup(AppCase):
                 sigs.worker_process_init.connect.assert_called_with(
                     f.on_worker_process_init,
                 )
-                self.assertEqual(app.loader.now, f.now)
-                self.assertEqual(app.loader.mail_admins, f.mail_admins)
+                self.assertEqual(self.app.loader.now, f.now)
+                self.assertEqual(self.app.loader.mail_admins, f.mail_admins)
                 p.append.assert_called_with('/opt/vandelay')
 
     def test_now(self):

+ 13 - 9
celery/tests/functional/case.py

@@ -11,8 +11,9 @@ import traceback
 from itertools import count
 from time import time
 
+from celery import current_app
 from celery.exceptions import TimeoutError
-from celery.task.control import ping, flatten_reply, inspect
+from celery.app.control import flatten_reply
 from celery.utils.imports import qualname
 
 from celery.tests.case import Case
@@ -39,9 +40,10 @@ class Worker(object):
     worker_ids = count(1)
     _shutdown_called = False
 
-    def __init__(self, hostname, loglevel='error'):
+    def __init__(self, hostname, loglevel='error', app=None):
         self.hostname = hostname
         self.loglevel = loglevel
+        self.app = app or current_app._get_current_object()
 
     def start(self):
         if not self.started:
@@ -51,16 +53,17 @@ class Worker(object):
     def _fork_and_exec(self):
         pid = os.fork()
         if pid == 0:
-            from celery import current_app
-            current_app.worker_main(['worker', '--loglevel=INFO',
-                                               '-n', self.hostname,
-                                               '-P', 'solo'])
+            self.app.worker_main(['worker', '--loglevel=INFO',
+                                  '-n', self.hostname,
+                                  '-P', 'solo'])
             os._exit(0)
         self.pid = pid
 
+    def ping(self, *args, **kwargs):
+        return self.app.control.ping(*args, **kwargs)
+
     def is_alive(self, timeout=1):
-        r = ping(destination=[self.hostname],
-                 timeout=timeout)
+        r = self.ping(destination=[self.hostname], timeout=timeout)
         return self.hostname in flatten_reply(r)
 
     def wait_until_started(self, timeout=10, interval=0.5):
@@ -124,7 +127,8 @@ class WorkerCase(Case):
         self.assertTrue(self.worker.is_alive)
 
     def inspect(self, timeout=1):
-        return inspect([self.worker.hostname], timeout=timeout)
+        return self.app.control.inspect([self.worker.hostname],
+                                        timeout=timeout)
 
     def my_response(self, response):
         return flatten_reply(response)[self.worker.hostname]

+ 25 - 37
celery/tests/security/test_security.py

@@ -32,7 +32,7 @@ from celery.tests.case import mock_open
 
 class test_security(SecurityCase):
 
-    def tearDown(self):
+    def teardown(self):
         registry._disabled_content_types.clear()
 
     def test_disable_insecure_serializers(self):
@@ -59,14 +59,10 @@ class test_security(SecurityCase):
         disabled = registry._disabled_content_types
         self.assertEqual(0, len(disabled))
 
-        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
-            self.app.conf.CELERY_TASK_SERIALIZER, 'json')
-        try:
-            self.app.setup_security()
-            self.assertIn('application/x-python-serialize', disabled)
-            disabled.clear()
-        finally:
-            self.app.conf.CELERY_TASK_SERIALIZER = prev
+        self.app.conf.CELERY_TASK_SERIALIZER = 'json'
+        self.app.setup_security()
+        self.assertIn('application/x-python-serialize', disabled)
+        disabled.clear()
 
     @patch('celery.security.register_auth')
     @patch('celery.security._disable_insecure_serializers')
@@ -81,39 +77,31 @@ class test_security(SecurityCase):
             finally:
                 calls[0] += 1
 
-        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
-            self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
-        try:
-            with mock_open(side_effect=effect):
-                with patch('celery.security.registry') as registry:
-                    store = Mock()
-                    self.app.setup_security(['json'], key, cert, store)
-                    dis.assert_called_with(['json'])
-                    reg.assert_called_with('A', 'B', store, 'sha1', 'json')
-                    registry._set_default_serializer.assert_called_with('auth')
-        finally:
-            self.app.conf.CELERY_TASK_SERIALIZER = prev
+        self.app.conf.CELERY_TASK_SERIALIZER = 'auth'
+        with mock_open(side_effect=effect):
+            with patch('celery.security.registry') as registry:
+                store = Mock()
+                self.app.setup_security(['json'], key, cert, store)
+                dis.assert_called_with(['json'])
+                reg.assert_called_with('A', 'B', store, 'sha1', 'json')
+                registry._set_default_serializer.assert_called_with('auth')
 
     def test_security_conf(self):
-        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
-            self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                self.app.setup_security()
+        self.app.conf.CELERY_TASK_SERIALIZER = 'auth'
+        with self.assertRaises(ImproperlyConfigured):
+            self.app.setup_security()
 
-            _import = builtins.__import__
+        _import = builtins.__import__
 
-            def import_hook(name, *args, **kwargs):
-                if name == 'OpenSSL':
-                    raise ImportError
-                return _import(name, *args, **kwargs)
+        def import_hook(name, *args, **kwargs):
+            if name == 'OpenSSL':
+                raise ImportError
+            return _import(name, *args, **kwargs)
 
-            builtins.__import__ = import_hook
-            with self.assertRaises(ImproperlyConfigured):
-                self.app.setup_security()
-            builtins.__import__ = _import
-        finally:
-            self.app.conf.CELERY_TASK_SERIALIZER = prev
+        builtins.__import__ = import_hook
+        with self.assertRaises(ImproperlyConfigured):
+            self.app.setup_security()
+        builtins.__import__ = _import
 
     def test_reraise_errors(self):
         with self.assertRaises(SecurityError):

+ 13 - 20
celery/tests/tasks/test_canvas.py

@@ -140,20 +140,17 @@ class test_Signature(CanvasCase):
     def test_election(self):
         x = self.add.s(2, 2)
         x.freeze('foo')
-        prev, x.type.app.control = x.type.app.control, Mock()
-        try:
-            r = x.election()
-            self.assertTrue(x.type.app.control.election.called)
-            self.assertEqual(r.id, 'foo')
-        finally:
-            x.type.app.control = prev
-
-    def test_AsyncResult_when_not_registerd(self):
-        s = subtask('xxx.not.registered')
+        x.type.app.control = Mock()
+        r = x.election()
+        self.assertTrue(x.type.app.control.election.called)
+        self.assertEqual(r.id, 'foo')
+
+    def test_AsyncResult_when_not_registered(self):
+        s = subtask('xxx.not.registered', app=self.app)
         self.assertTrue(s.AsyncResult)
 
     def test_apply_async_when_not_registered(self):
-        s = subtask('xxx.not.registered')
+        s = subtask('xxx.not.registered', app=self.app)
         self.assertTrue(s._apply_async)
 
 
@@ -178,7 +175,9 @@ class test_chunks(CanvasCase):
 
     def test_chunks(self):
         x = self.add.chunks(range(100), 10)
-        self.assertEqual(chunks.from_dict(dict(x)), x)
+        self.assertEqual(
+            dict(chunks.from_dict(dict(x), app=self.app)), dict(x),
+        )
 
         self.assertTrue(x.group())
         self.assertEqual(len(x.group().tasks), 10)
@@ -193,10 +192,7 @@ class test_chunks(CanvasCase):
         gr.assert_called_with()
 
         self.app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            chunks.apply_chunks(**x['kwargs'])
-        finally:
-            self.app.conf.CELERY_ALWAYS_EAGER = False
+        chunks.apply_chunks(app=self.app, **x['kwargs'])
 
 
 class test_chain(CanvasCase):
@@ -214,10 +210,7 @@ class test_chain(CanvasCase):
 
     def test_always_eager(self):
         self.app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            self.assertEqual(~(self.add.s(4, 4) | self.add.s(8)), 16)
-        finally:
-            self.app.conf.CELERY_ALWAYS_EAGER = False
+        self.assertEqual(~(self.add.s(4, 4) | self.add.s(8)), 16)
 
     def test_apply(self):
         x = chain(self.add.s(4, 4), self.add.s(8), self.add.s(10))

+ 13 - 19
celery/tests/tasks/test_chord.py

@@ -150,7 +150,7 @@ class test_unlock_chord_task(ChordCase):
                     assert self.app.tasks['celery.chord_unlock'] is unlock
                     unlock(
                         'group_id', callback_s,
-                        result=[AsyncResult(r) for r in ['1', 2, 3]],
+                        result=[self.app.AsyncResult(r) for r in ['1', 2, 3]],
                         GroupResult=ResultCls, **kwargs
                     )
                 finally:
@@ -178,22 +178,19 @@ class test_chord(ChordCase):
     def test_eager(self):
         from celery import chord
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def addX(x, y):
             return x + y
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def sumX(n):
             return sum(n)
 
         self.app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            x = chord(addX.s(i, i) for i in range(10))
-            body = sumX.s()
-            result = x(body)
-            self.assertEqual(result.get(), sum(i + i for i in range(10)))
-        finally:
-            self.app.conf.CELERY_ALWAYS_EAGER = False
+        x = chord(addX.s(i, i) for i in range(10))
+        body = sumX.s()
+        result = x(body)
+        self.assertEqual(result.get(), sum(i + i for i in range(10)))
 
     def test_apply(self):
         self.app.conf.CELERY_ALWAYS_EAGER = False
@@ -219,15 +216,12 @@ class test_chord(ChordCase):
 class test_Chord_task(ChordCase):
 
     def test_run(self):
-        prev, self.app.backend = self.app.backend, Mock()
+        self.app.backend = Mock()
         self.app.backend.cleanup = Mock()
         self.app.backend.cleanup.__name__ = 'cleanup'
-        try:
-            Chord = self.app.tasks['celery.chord']
+        Chord = self.app.tasks['celery.chord']
 
-            body = dict()
-            Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
-            Chord([self.add.subtask((j, j)) for j in range(5)], body)
-            self.assertEqual(self.app.backend.on_chord_apply.call_count, 2)
-        finally:
-            self.app.backend = prev
+        body = dict()
+        Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
+        Chord([self.add.subtask((j, j)) for j in range(5)], body)
+        self.assertEqual(self.app.backend.on_chord_apply.call_count, 2)

+ 3 - 3
celery/tests/tasks/test_context.py

@@ -3,8 +3,8 @@ from __future__ import absolute_import
 
 from collections import Callable
 
-from celery.task.base import Context
-from celery.tests.case import Case
+from celery.app.task import Context
+from celery.tests.case import AppCase
 
 
 # Retreive the values of all context attributes as a
@@ -22,7 +22,7 @@ def get_context_as_dict(ctx, getter=getattr):
 default_context = get_context_as_dict(Context())
 
 
-class test_Context(Case):
+class test_Context(AppCase):
 
     def test_default_context(self):
         # A bit of a tautological test, since it uses the same

+ 16 - 16
celery/tests/tasks/test_result.py

@@ -10,15 +10,12 @@ from celery.result import (
     AsyncResult,
     EagerResult,
     TaskSetResult,
-    ResultSet,
-    #GroupResult,
     from_serializable,
 )
 from celery.utils import uuid
 from celery.utils.serialization import pickle
 
-from celery.tests.case import AppCase
-from celery.tests.case import skip_if_quick
+from celery.tests.case import AppCase, depends_on_current_app, skip_if_quick
 
 
 def mock_task(name, state, result):
@@ -56,7 +53,7 @@ class test_AsyncResult(AppCase):
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(self.app, task)
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mytask():
             pass
         self.mytask = mytask
@@ -150,6 +147,7 @@ class test_AsyncResult(AppCase):
     def test_eq_not_implemented(self):
         self.assertFalse(self.app.AsyncResult('1') == object())
 
+    @depends_on_current_app
     def test_reduce(self):
         a1 = self.app.AsyncResult('uuid', task_name=self.mytask.name)
         restored = pickle.loads(pickle.dumps(a1))
@@ -261,15 +259,15 @@ class test_AsyncResult(AppCase):
 class test_ResultSet(AppCase):
 
     def test_resultset_repr(self):
-        self.assertTrue(repr(ResultSet(
+        self.assertTrue(repr(self.app.ResultSet(
             [self.app.AsyncResult(t) for t in ['1', '2', '3']])))
 
     def test_eq_other(self):
-        self.assertFalse(ResultSet([1, 3, 3]) == 1)
-        self.assertTrue(ResultSet([1]) == ResultSet([1]))
+        self.assertFalse(self.app.ResultSet([1, 3, 3]) == 1)
+        self.assertTrue(self.app.ResultSet([1]) == self.app.ResultSet([1]))
 
     def test_get(self):
-        x = ResultSet([self.app.AsyncResult(t) for t in [1, 2, 3]])
+        x = self.app.ResultSet([self.app.AsyncResult(t) for t in [1, 2, 3]])
         b = x.results[0].backend = Mock()
         b.supports_native_join = False
         x.join_native = Mock()
@@ -281,7 +279,7 @@ class test_ResultSet(AppCase):
         self.assertTrue(x.join_native.called)
 
     def test_add(self):
-        x = ResultSet([1])
+        x = self.app.ResultSet([1])
         x.add(2)
         self.assertEqual(len(x), 2)
         x.add(2)
@@ -311,7 +309,7 @@ class test_ResultSet(AppCase):
         ready.return_value = False
         ready.side_effect = se
 
-        x = ResultSet([r1, r2])
+        x = self.app.ResultSet([r1, r2])
         with self.dummy_copy():
             with patch('celery.result.time') as _time:
                 with self.assertRaises(KeyError):
@@ -330,14 +328,14 @@ class test_ResultSet(AppCase):
         r1 = self.app.AsyncResult(uuid)
         r1.ready = Mock()
         r1.ready.return_value = False
-        x = ResultSet([r1])
+        x = self.app.ResultSet([r1])
         with self.dummy_copy():
             with patch('celery.result.time'):
                 with self.assertRaises(TimeoutError):
                     list(x.iterate(timeout=1))
 
     def test_add_discard(self):
-        x = ResultSet([])
+        x = self.app.ResultSet([])
         x.add(self.app.AsyncResult('1'))
         self.assertIn(self.app.AsyncResult('1'), x.results)
         x.discard(self.app.AsyncResult('1'))
@@ -348,7 +346,7 @@ class test_ResultSet(AppCase):
         x.update([self.app.AsyncResult('2')])
 
     def test_clear(self):
-        x = ResultSet([])
+        x = self.app.ResultSet([])
         r = x.results
         x.clear()
         self.assertIs(x.results, r)
@@ -432,6 +430,7 @@ class test_GroupResult(AppCase):
             uuid(), make_mock_group(self.app, self.size),
         )
 
+    @depends_on_current_app
     def test_is_pickleable(self):
         ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
         self.assertEqual(pickle.loads(pickle.dumps(ts)), ts)
@@ -444,6 +443,7 @@ class test_GroupResult(AppCase):
     def test_eq_other(self):
         self.assertFalse(self.ts == 1)
 
+    @depends_on_current_app
     def test_reduce(self):
         self.assertTrue(pickle.loads(pickle.dumps(self.ts)))
 
@@ -660,7 +660,7 @@ class test_EagerResult(AppCase):
 
     def setup(self):
 
-        @self.app.task
+        @self.app.task(shared=False)
         def raising(x, y):
             raise KeyError(x, y)
         self.raising = raising
@@ -703,7 +703,7 @@ class test_serializable(AppCase):
 
     def test_compat(self):
         uid = uuid()
-        x = from_serializable([uid, []])
+        x = from_serializable([uid, []], app=self.app)
         self.assertEqual(x.id, uid)
 
     def test_GroupResult(self):

+ 49 - 815
celery/tests/tasks/test_tasks.py

@@ -1,29 +1,20 @@
 from __future__ import absolute_import
-import time
+
 from collections import Callable
 from datetime import datetime, timedelta
-from functools import wraps
 from mock import patch
-from nose import SkipTest
-from pickle import loads, dumps
 
 from kombu import Queue
 
 from celery import Task
 
-from celery.task import (
-    periodic_task,
-    PeriodicTask
-)
 from celery.exceptions import RetryTaskError
-from celery.execute import send_task
 from celery.five import items, range, string_t
 from celery.result import EagerResult
-from celery.schedules import crontab, crontab_parser, ParseException
 from celery.utils import uuid
-from celery.utils.timeutils import parse_iso8601, timedelta_seconds
+from celery.utils.timeutils import parse_iso8601
 
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 def return_True(*args, **kwargs):
@@ -49,8 +40,7 @@ class MockApplyTask(Task):
 class TasksCase(AppCase):
 
     def setup(self):
-
-        self.return_True_task = self.app.task(shared=False)(return_True)
+        self.mytask = self.app.task(shared=False)(return_True)
 
         @self.app.task(bind=True, count=0, shared=False)
         def increment_counter(self, increment_by=1):
@@ -225,8 +215,8 @@ class test_tasks(TasksCase):
     def now(self):
         return self.app.now()
 
+    @depends_on_current_app
     def test_unpickle_task(self):
-        self.app.set_current()
         import pickle
 
         @self.app.task(shared=True)
@@ -234,10 +224,6 @@ class test_tasks(TasksCase):
             pass
         self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
 
-    def create_task(self, name):
-        return self.app.task(__module__=self.__module__,
-                             shared=False, name=name)(return_True)
-
     def test_AsyncResult(self):
         task_id = uuid()
         result = self.retry_task.AsyncResult(task_id)
@@ -265,6 +251,7 @@ class test_tasks(TasksCase):
     def test_incomplete_task_cls(self):
 
         class IncompleteTask(Task):
+            app = self.app
             name = 'c.unittest.t.itask'
 
         with self.assertRaises(NotImplementedError):
@@ -279,11 +266,11 @@ class test_tasks(TasksCase):
             self.increment_counter.apply_async('str', {})
 
     def test_regular_task(self):
-        T1 = self.create_task('c.unittest.t.t1')
-        self.assertIsInstance(T1, Task)
-        self.assertTrue(T1.run())
-        self.assertTrue(isinstance(T1, Callable), 'Task class is callable()')
-        self.assertTrue(T1(), 'Task class runs run() when called')
+        self.assertIsInstance(self.mytask, Task)
+        self.assertTrue(self.mytask.run())
+        self.assertTrue(isinstance(self.mytask, Callable),
+                        'Task class is callable()')
+        self.assertTrue(self.mytask(), 'Task class runs run() when called')
 
         with self.app.connection_or_acquire() as conn:
             consumer = self.app.amqp.TaskConsumer(conn)
@@ -294,54 +281,57 @@ class test_tasks(TasksCase):
             self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])
 
             # Without arguments.
-            presult = T1.delay()
-            self.assertNextTaskDataEqual(consumer, presult, T1.name)
+            presult = self.mytask.delay()
+            self.assertNextTaskDataEqual(consumer, presult, self.mytask.name)
 
             # With arguments.
-            presult2 = T1.apply_async(kwargs=dict(name='George Costanza'))
+            presult2 = self.mytask.apply_async(
+                kwargs=dict(name='George Costanza'),
+            )
             self.assertNextTaskDataEqual(
-                consumer, presult2, T1.name, name='George Costanza',
+                consumer, presult2, self.mytask.name, name='George Costanza',
             )
 
             # send_task
-            sresult = send_task(T1.name, kwargs=dict(name='Elaine M. Benes'))
+            sresult = self.app.send_task(self.mytask.name,
+                                         kwargs=dict(name='Elaine M. Benes'))
             self.assertNextTaskDataEqual(
-                consumer, sresult, T1.name, name='Elaine M. Benes',
+                consumer, sresult, self.mytask.name, name='Elaine M. Benes',
             )
 
             # With eta.
-            presult2 = T1.apply_async(
+            presult2 = self.mytask.apply_async(
                 kwargs=dict(name='George Costanza'),
                 eta=self.now() + timedelta(days=1),
                 expires=self.now() + timedelta(days=2),
             )
             self.assertNextTaskDataEqual(
-                consumer, presult2, T1.name,
+                consumer, presult2, self.mytask.name,
                 name='George Costanza', test_eta=True, test_expires=True,
             )
 
             # With countdown.
-            presult2 = T1.apply_async(kwargs=dict(name='George Costanza'),
-                                      countdown=10, expires=12)
+            presult2 = self.mytask.apply_async(
+                kwargs=dict(name='George Costanza'), countdown=10, expires=12,
+            )
             self.assertNextTaskDataEqual(
-                consumer, presult2, T1.name,
+                consumer, presult2, self.mytask.name,
                 name='George Costanza', test_eta=True, test_expires=True,
             )
 
             # Discarding all tasks.
             consumer.purge()
-            T1.apply_async()
+            self.mytask.apply_async()
             self.assertEqual(consumer.purge(), 1)
             self.assertIsNone(consumer.queues[0].get())
 
             self.assertFalse(presult.successful())
-            T1.backend.mark_as_done(presult.id, result=None)
+            self.mytask.backend.mark_as_done(presult.id, result=None)
             self.assertTrue(presult.successful())
 
     def test_repr_v2_compat(self):
-        task = type(self.create_task('c.unittest.v2c')._get_current_object())
-        task.__v2_compat__ = True
-        self.assertIn('v2 compatible', repr(task))
+        self.mytask.__v2_compat__ = True
+        self.assertIn('v2 compatible', repr(self.mytask))
 
     def test_apply_with_self(self):
 
@@ -354,46 +344,43 @@ class test_tasks(TasksCase):
         self.assertEqual(tawself(), 42)
 
     def test_context_get(self):
-        task = self.create_task('c.unittest.t.c.g')
-        task.push_request()
+        self.mytask.push_request()
         try:
-            request = task.request
+            request = self.mytask.request
             request.foo = 32
             self.assertEqual(request.get('foo'), 32)
             self.assertEqual(request.get('bar', 36), 36)
             request.clear()
         finally:
-            task.pop_request()
+            self.mytask.pop_request()
 
     def test_task_class_repr(self):
-        task = self.create_task('c.unittest.t.repr')
-        self.assertIn('class Task of', repr(task.app.Task))
-        prev, task.app.Task._app = task.app.Task._app, None
-        try:
-            self.assertIn('unbound', repr(task.app.Task, ))
-        finally:
-            task.app.Task._app = prev
+        self.assertIn('class Task of', repr(self.mytask.app.Task))
+        self.mytask.app.Task._app = None
+        self.assertIn('unbound', repr(self.mytask.app.Task, ))
 
     def test_bind_no_magic_kwargs(self):
-        task = self.create_task('c.unittest.t.magic_kwargs')
-        task.accept_magic_kwargs = None
-        task.bind(task.app)
+        self.mytask.accept_magic_kwargs = None
+        self.mytask.bind(self.mytask.app)
 
     def test_annotate(self):
         with patch('celery.app.task.resolve_all_annotations') as anno:
             anno.return_value = [{'FOO': 'BAR'}]
-            Task.annotate()
-            self.assertEqual(Task.FOO, 'BAR')
+
+            @self.app.task(shared=False)
+            def task():
+                pass
+            task.annotate()
+            self.assertEqual(task.FOO, 'BAR')
 
     def test_after_return(self):
-        task = self.create_task('c.unittest.t.after_return')
-        task.push_request()
+        self.mytask.push_request()
         try:
-            task.request.chord = self.return_True_task.s()
-            task.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
-            task.request.clear()
+            self.mytask.request.chord = self.mytask.s()
+            self.mytask.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
+            self.mytask.request.clear()
         finally:
-            task.pop_request()
+            self.mytask.pop_request()
 
     def test_send_task_sent_event(self):
         with self.app.connection() as conn:
@@ -471,756 +458,3 @@ class test_apply_task(TasksCase):
         self.assertTrue(f.traceback)
         with self.assertRaises(KeyError):
             f.get()
-
-
-@periodic_task(run_every=timedelta(hours=1))
-def my_periodic():
-    pass
-
-
-class test_periodic_tasks(AppCase):
-
-    def now(self):
-        return self.app.now()
-
-    def test_must_have_run_every(self):
-        with self.assertRaises(NotImplementedError):
-            type('Foo', (PeriodicTask, ), {'__module__': __name__})
-
-    def test_remaining_estimate(self):
-        s = my_periodic.run_every
-        self.assertIsInstance(
-            s.remaining_estimate(s.maybe_make_aware(self.now())),
-            timedelta)
-
-    def test_is_due_not_due(self):
-        due, remaining = my_periodic.run_every.is_due(self.now())
-        self.assertFalse(due)
-        # This assertion may fail if executed in the
-        # first minute of an hour, thus 59 instead of 60
-        self.assertGreater(remaining, 59)
-
-    def test_is_due(self):
-        p = my_periodic
-        due, remaining = p.run_every.is_due(
-            self.now() - p.run_every.run_every,
-        )
-        self.assertTrue(due)
-        self.assertEqual(remaining,
-                         timedelta_seconds(p.run_every.run_every))
-
-    def test_schedule_repr(self):
-        p = my_periodic
-        self.assertTrue(repr(p.run_every))
-
-
-@periodic_task(run_every=crontab())
-def every_minute():
-    pass
-
-
-@periodic_task(run_every=crontab(minute='*/15'))
-def quarterly():
-    pass
-
-
-@periodic_task(run_every=crontab(minute=30))
-def hourly():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30))
-def daily():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30,
-                                 day_of_week='thursday'))
-def weekly():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30,
-                                 day_of_week='thursday',
-                                 day_of_month='8-14'))
-def monthly():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=22,
-                                 day_of_week='*',
-                                 month_of_year='2',
-                                 day_of_month='26,27,28'))
-def monthly_moy():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30,
-                                 day_of_week='thursday',
-                                 day_of_month='8-14',
-                                 month_of_year=3))
-def yearly():
-    pass
-
-
-def patch_crontab_nowfun(cls, retval):
-
-    def create_patcher(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            prev_nowfun = cls.run_every.nowfun
-            cls.run_every.nowfun = lambda: retval
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                cls.run_every.nowfun = prev_nowfun
-
-        return __inner
-
-    return create_patcher
-
-
-class test_crontab_parser(AppCase):
-
-    def test_crontab_reduce(self):
-        self.assertTrue(loads(dumps(crontab('*'))))
-
-    def test_range_steps_not_enough(self):
-        with self.assertRaises(crontab_parser.ParseException):
-            crontab_parser(24)._range_steps([1])
-
-    def test_parse_star(self):
-        self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
-        self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))
-        self.assertEqual(crontab_parser(7).parse('*'), set(range(7)))
-        self.assertEqual(crontab_parser(31, 1).parse('*'),
-                         set(range(1, 31 + 1)))
-        self.assertEqual(crontab_parser(12, 1).parse('*'),
-                         set(range(1, 12 + 1)))
-
-    def test_parse_range(self):
-        self.assertEqual(crontab_parser(60).parse('1-10'),
-                         set(range(1, 10 + 1)))
-        self.assertEqual(crontab_parser(24).parse('0-20'),
-                         set(range(0, 20 + 1)))
-        self.assertEqual(crontab_parser().parse('2-10'),
-                         set(range(2, 10 + 1)))
-        self.assertEqual(crontab_parser(60, 1).parse('1-10'),
-                         set(range(1, 10 + 1)))
-
-    def test_parse_range_wraps(self):
-        self.assertEqual(crontab_parser(12).parse('11-1'),
-                         set([11, 0, 1]))
-        self.assertEqual(crontab_parser(60, 1).parse('2-1'),
-                         set(range(1, 60 + 1)))
-
-    def test_parse_groups(self):
-        self.assertEqual(crontab_parser().parse('1,2,3,4'),
-                         set([1, 2, 3, 4]))
-        self.assertEqual(crontab_parser().parse('0,15,30,45'),
-                         set([0, 15, 30, 45]))
-        self.assertEqual(crontab_parser(min_=1).parse('1,2,3,4'),
-                         set([1, 2, 3, 4]))
-
-    def test_parse_steps(self):
-        self.assertEqual(crontab_parser(8).parse('*/2'),
-                         set([0, 2, 4, 6]))
-        self.assertEqual(crontab_parser().parse('*/2'),
-                         set(i * 2 for i in range(30)))
-        self.assertEqual(crontab_parser().parse('*/3'),
-                         set(i * 3 for i in range(20)))
-        self.assertEqual(crontab_parser(8, 1).parse('*/2'),
-                         set([1, 3, 5, 7]))
-        self.assertEqual(crontab_parser(min_=1).parse('*/2'),
-                         set(i * 2 + 1 for i in range(30)))
-        self.assertEqual(crontab_parser(min_=1).parse('*/3'),
-                         set(i * 3 + 1 for i in range(20)))
-
-    def test_parse_composite(self):
-        self.assertEqual(crontab_parser(8).parse('*/2'), set([0, 2, 4, 6]))
-        self.assertEqual(crontab_parser().parse('2-9/5'), set([2, 7]))
-        self.assertEqual(crontab_parser().parse('2-10/5'), set([2, 7]))
-        self.assertEqual(
-            crontab_parser(min_=1).parse('55-5/3'),
-            set([55, 58, 1, 4]),
-        )
-        self.assertEqual(crontab_parser().parse('2-11/5,3'), set([2, 3, 7]))
-        self.assertEqual(
-            crontab_parser().parse('2-4/3,*/5,0-21/4'),
-            set([0, 2, 4, 5, 8, 10, 12, 15, 16,
-                 20, 25, 30, 35, 40, 45, 50, 55]),
-        )
-        self.assertEqual(
-            crontab_parser().parse('1-9/2'),
-            set([1, 3, 5, 7, 9]),
-        )
-        self.assertEqual(crontab_parser(8, 1).parse('*/2'), set([1, 3, 5, 7]))
-        self.assertEqual(crontab_parser(min_=1).parse('2-9/5'), set([2, 7]))
-        self.assertEqual(crontab_parser(min_=1).parse('2-10/5'), set([2, 7]))
-        self.assertEqual(
-            crontab_parser(min_=1).parse('2-11/5,3'),
-            set([2, 3, 7]),
-        )
-        self.assertEqual(
-            crontab_parser(min_=1).parse('2-4/3,*/5,1-21/4'),
-            set([1, 2, 5, 6, 9, 11, 13, 16, 17,
-                 21, 26, 31, 36, 41, 46, 51, 56]),
-        )
-        self.assertEqual(
-            crontab_parser(min_=1).parse('1-9/2'),
-            set([1, 3, 5, 7, 9]),
-        )
-
-    def test_parse_errors_on_empty_string(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('')
-
-    def test_parse_errors_on_empty_group(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('1,,2')
-
-    def test_parse_errors_on_empty_steps(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('*/')
-
-    def test_parse_errors_on_negative_number(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('-20')
-
-    def test_parse_errors_on_lt_min(self):
-        crontab_parser(min_=1).parse('1')
-        with self.assertRaises(ValueError):
-            crontab_parser(12, 1).parse('0')
-        with self.assertRaises(ValueError):
-            crontab_parser(24, 1).parse('12-0')
-
-    def test_parse_errors_on_gt_max(self):
-        crontab_parser(1).parse('0')
-        with self.assertRaises(ValueError):
-            crontab_parser(1).parse('1')
-        with self.assertRaises(ValueError):
-            crontab_parser(60).parse('61-0')
-
-    def test_expand_cronspec_eats_iterables(self):
-        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
-                         set([1, 2, 3]))
-        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100, 1),
-                         set([1, 2, 3]))
-
-    def test_expand_cronspec_invalid_type(self):
-        with self.assertRaises(TypeError):
-            crontab._expand_cronspec(object(), 100)
-
-    def test_repr(self):
-        self.assertIn('*', repr(crontab('*')))
-
-    def test_eq(self):
-        self.assertEqual(crontab(day_of_week='1, 2'),
-                         crontab(day_of_week='1-2'))
-        self.assertEqual(crontab(day_of_month='1, 16, 31'),
-                         crontab(day_of_month='*/15'))
-        self.assertEqual(crontab(minute='1', hour='2', day_of_week='5',
-                                 day_of_month='10', month_of_year='5'),
-                         crontab(minute='1', hour='2', day_of_week='5',
-                                 day_of_month='10', month_of_year='5'))
-        self.assertNotEqual(crontab(minute='1'), crontab(minute='2'))
-        self.assertNotEqual(crontab(month_of_year='1'),
-                            crontab(month_of_year='2'))
-        self.assertFalse(object() == crontab(minute='1'))
-        self.assertFalse(crontab(minute='1') == object())
-
-
-class test_crontab_remaining_estimate(AppCase):
-
-    def next_ocurrance(self, crontab, now):
-        crontab.nowfun = lambda: now
-        return now + crontab.remaining_estimate(now)
-
-    def test_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 14, 31))
-
-    def test_not_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 59, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 15, 0))
-
-    def test_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 14, 42))
-
-    def test_not_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 10, 15]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 15, 5))
-
-    def test_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12, 17]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 17, 5))
-
-    def test_not_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 12, 12, 5))
-
-    def test_weekday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='sat'),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 18, 14, 30))
-
-    def test_not_weekday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon-fri'),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 13, 0, 5))
-
-    def test_monthday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_month=18),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 18, 14, 30))
-
-    def test_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_month=29),
-                                   datetime(2010, 1, 22, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 1, 29, 0, 5))
-
-    def test_weekday_monthday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='mon',
-                                           day_of_month=18),
-                                   datetime(2010, 1, 18, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 10, 18, 14, 30))
-
-    def test_monthday_not_weekday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='sat',
-                                           day_of_month=29),
-                                   datetime(2010, 1, 29, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
-
-    def test_weekday_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=18),
-                                   datetime(2010, 1, 11, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 1, 18, 0, 5))
-
-    def test_not_weekday_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=18),
-                                   datetime(2010, 1, 10, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 1, 18, 0, 5))
-
-    def test_leapday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_month=29),
-                                   datetime(2012, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2012, 2, 29, 14, 30))
-
-    def test_not_leapday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_month=29),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 3, 29, 14, 30))
-
-    def test_weekmonthdayyear(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='fri',
-                                           day_of_month=29,
-                                           month_of_year=1),
-                                   datetime(2010, 1, 22, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 1, 29, 14, 30))
-
-    def test_monthdayyear_not_week(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='wed,thu',
-                                           day_of_month=29,
-                                           month_of_year='1,4,7'),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 4, 29, 0, 5))
-
-    def test_weekdaymonthyear_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='fri',
-                                           day_of_month=29,
-                                           month_of_year='1-10'),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 10, 29, 14, 30))
-
-    def test_weekmonthday_not_monthyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='fri',
-                                           day_of_month=29,
-                                           month_of_year='2-10'),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 10, 29, 0, 5))
-
-    def test_weekday_not_monthdayyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=18,
-                                           month_of_year='2-10'),
-                                   datetime(2010, 1, 11, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 10, 18, 0, 5))
-
-    def test_monthday_not_weekdaymonthyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=29,
-                                           month_of_year='2-4'),
-                                   datetime(2010, 1, 29, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 3, 29, 0, 5))
-
-    def test_monthyear_not_weekmonthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=29,
-                                           month_of_year='2-4'),
-                                   datetime(2010, 2, 28, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 3, 29, 0, 5))
-
-    def test_not_weekmonthdayyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='fri,sat',
-                                           day_of_month=29,
-                                           month_of_year='2-10'),
-                                   datetime(2010, 1, 28, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
-
-
-class test_crontab_is_due(AppCase):
-
-    def getnow(self):
-        return self.app.now()
-
-    def setup(self):
-        self.now = self.getnow()
-        self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
-
-    def test_default_crontab_spec(self):
-        c = crontab()
-        self.assertEqual(c.minute, set(range(60)))
-        self.assertEqual(c.hour, set(range(24)))
-        self.assertEqual(c.day_of_week, set(range(7)))
-        self.assertEqual(c.day_of_month, set(range(1, 32)))
-        self.assertEqual(c.month_of_year, set(range(1, 13)))
-
-    def test_simple_crontab_spec(self):
-        c = crontab(minute=30)
-        self.assertEqual(c.minute, set([30]))
-        self.assertEqual(c.hour, set(range(24)))
-        self.assertEqual(c.day_of_week, set(range(7)))
-        self.assertEqual(c.day_of_month, set(range(1, 32)))
-        self.assertEqual(c.month_of_year, set(range(1, 13)))
-
-    def test_crontab_spec_minute_formats(self):
-        c = crontab(minute=30)
-        self.assertEqual(c.minute, set([30]))
-        c = crontab(minute='30')
-        self.assertEqual(c.minute, set([30]))
-        c = crontab(minute=(30, 40, 50))
-        self.assertEqual(c.minute, set([30, 40, 50]))
-        c = crontab(minute=set([30, 40, 50]))
-        self.assertEqual(c.minute, set([30, 40, 50]))
-
-    def test_crontab_spec_invalid_minute(self):
-        with self.assertRaises(ValueError):
-            crontab(minute=60)
-        with self.assertRaises(ValueError):
-            crontab(minute='0-100')
-
-    def test_crontab_spec_hour_formats(self):
-        c = crontab(hour=6)
-        self.assertEqual(c.hour, set([6]))
-        c = crontab(hour='5')
-        self.assertEqual(c.hour, set([5]))
-        c = crontab(hour=(4, 8, 12))
-        self.assertEqual(c.hour, set([4, 8, 12]))
-
-    def test_crontab_spec_invalid_hour(self):
-        with self.assertRaises(ValueError):
-            crontab(hour=24)
-        with self.assertRaises(ValueError):
-            crontab(hour='0-30')
-
-    def test_crontab_spec_dow_formats(self):
-        c = crontab(day_of_week=5)
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='5')
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='fri')
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='tuesday,sunday,fri')
-        self.assertEqual(c.day_of_week, set([0, 2, 5]))
-        c = crontab(day_of_week='mon-fri')
-        self.assertEqual(c.day_of_week, set([1, 2, 3, 4, 5]))
-        c = crontab(day_of_week='*/2')
-        self.assertEqual(c.day_of_week, set([0, 2, 4, 6]))
-
-    def test_crontab_spec_invalid_dow(self):
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='fooday-barday')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='1,4,foo')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='7')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='12')
-
-    def test_crontab_spec_dom_formats(self):
-        c = crontab(day_of_month=5)
-        self.assertEqual(c.day_of_month, set([5]))
-        c = crontab(day_of_month='5')
-        self.assertEqual(c.day_of_month, set([5]))
-        c = crontab(day_of_month='2,4,6')
-        self.assertEqual(c.day_of_month, set([2, 4, 6]))
-        c = crontab(day_of_month='*/5')
-        self.assertEqual(c.day_of_month, set([1, 6, 11, 16, 21, 26, 31]))
-
-    def test_crontab_spec_invalid_dom(self):
-        with self.assertRaises(ValueError):
-            crontab(day_of_month=0)
-        with self.assertRaises(ValueError):
-            crontab(day_of_month='0-10')
-        with self.assertRaises(ValueError):
-            crontab(day_of_month=32)
-        with self.assertRaises(ValueError):
-            crontab(day_of_month='31,32')
-
-    def test_crontab_spec_moy_formats(self):
-        c = crontab(month_of_year=1)
-        self.assertEqual(c.month_of_year, set([1]))
-        c = crontab(month_of_year='1')
-        self.assertEqual(c.month_of_year, set([1]))
-        c = crontab(month_of_year='2,4,6')
-        self.assertEqual(c.month_of_year, set([2, 4, 6]))
-        c = crontab(month_of_year='*/2')
-        self.assertEqual(c.month_of_year, set([1, 3, 5, 7, 9, 11]))
-        c = crontab(month_of_year='2-12/2')
-        self.assertEqual(c.month_of_year, set([2, 4, 6, 8, 10, 12]))
-
-    def test_crontab_spec_invalid_moy(self):
-        with self.assertRaises(ValueError):
-            crontab(month_of_year=0)
-        with self.assertRaises(ValueError):
-            crontab(month_of_year='0-5')
-        with self.assertRaises(ValueError):
-            crontab(month_of_year=13)
-        with self.assertRaises(ValueError):
-            crontab(month_of_year='12,13')
-
-    def seconds_almost_equal(self, a, b, precision):
-        for index, skew in enumerate((+0.1, 0, -0.1)):
-            try:
-                self.assertAlmostEqual(a, b + skew, precision)
-            except AssertionError:
-                if index + 1 >= 3:
-                    raise
-            else:
-                break
-
-    def assertRelativedelta(self, due, last_ran):
-        try:
-            from dateutil.relativedelta import relativedelta
-        except ImportError:
-            return
-        l1, d1, n1 = due.run_every.remaining_delta(last_ran)
-        l2, d2, n2 = due.run_every.remaining_delta(last_ran,
-                                                   ffwd=relativedelta)
-        if not isinstance(d1, relativedelta):
-            self.assertEqual(l1, l2)
-            for field, value in items(d1._fields()):
-                self.assertEqual(getattr(d1, field), value)
-            self.assertFalse(d2.years)
-            self.assertFalse(d2.months)
-            self.assertFalse(d2.days)
-            self.assertFalse(d2.leapdays)
-            self.assertFalse(d2.hours)
-            self.assertFalse(d2.minutes)
-            self.assertFalse(d2.seconds)
-            self.assertFalse(d2.microseconds)
-
-    def test_every_minute_execution_is_due(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertRelativedelta(every_minute, last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    def test_every_minute_execution_is_not_due(self):
-        last_ran = self.now - timedelta(seconds=self.now.second)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertFalse(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 29th of May 2010 is a saturday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 29, 10, 30))
-    def test_execution_is_due_on_saturday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 30th of May 2010 is a sunday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 30, 10, 30))
-    def test_execution_is_due_on_sunday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 31st of May 2010 is a monday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 31, 10, 30))
-    def test_execution_is_due_on_monday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 30))
-    def test_every_hour_execution_is_due(self):
-        due, remaining = hourly.run_every.is_due(
-            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 60 * 60)
-
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 29))
-    def test_every_hour_execution_is_not_due(self):
-        due, remaining = hourly.run_every.is_due(
-            datetime(2010, 5, 10, 9, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 15))
-    def test_first_quarter_execution_is_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 30))
-    def test_second_quarter_execution_is_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 14))
-    def test_first_quarter_execution_is_not_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 10, 0))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 29))
-    def test_second_quarter_execution_is_not_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 10, 15))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 7, 30))
-    def test_daily_execution_is_due(self):
-        due, remaining = daily.run_every.is_due(
-            datetime(2010, 5, 9, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 24 * 60 * 60)
-
-    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 10, 30))
-    def test_daily_execution_is_not_due(self):
-        due, remaining = daily.run_every.is_due(
-            datetime(2010, 5, 10, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 21 * 60 * 60)
-
-    @patch_crontab_nowfun(weekly, datetime(2010, 5, 6, 7, 30))
-    def test_weekly_execution_is_due(self):
-        due, remaining = weekly.run_every.is_due(
-            datetime(2010, 4, 30, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 7 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(weekly, datetime(2010, 5, 7, 10, 30))
-    def test_weekly_execution_is_not_due(self):
-        due, remaining = weekly.run_every.is_due(
-            datetime(2010, 5, 6, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 6 * 24 * 60 * 60 - 3 * 60 * 60)
-
-    @patch_crontab_nowfun(monthly, datetime(2010, 5, 13, 7, 30))
-    def test_monthly_execution_is_due(self):
-        due, remaining = monthly.run_every.is_due(
-            datetime(2010, 4, 8, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 28 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(monthly, datetime(2010, 5, 9, 10, 30))
-    def test_monthly_execution_is_not_due(self):
-        due, remaining = monthly.run_every.is_due(
-            datetime(2010, 4, 8, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
-    def test_monthly_moy_execution_is_due(self):
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 7, 4, 10, 0))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 60.)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2013, 6, 28, 14, 30))
-    def test_monthly_moy_execution_is_not_due(self):
-        raise SkipTest('unstable test')
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 6, 28, 22, 14))
-        self.assertFalse(due)
-        attempt = (
-            time.mktime(datetime(2014, 2, 26, 22, 0).timetuple()) -
-            time.mktime(datetime(2013, 6, 28, 14, 30).timetuple()) -
-            60 * 60
-        )
-        self.assertEqual(remaining, attempt)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
-    def test_monthly_moy_execution_is_due2(self):
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 2, 28, 10, 0))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 60.)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 21, 0))
-    def test_monthly_moy_execution_is_not_due2(self):
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 6, 28, 22, 14))
-        self.assertFalse(due)
-        attempt = 60 * 60
-        self.assertEqual(remaining, attempt)
-
-    @patch_crontab_nowfun(yearly, datetime(2010, 3, 11, 7, 30))
-    def test_yearly_execution_is_due(self):
-        due, remaining = yearly.run_every.is_due(
-            datetime(2009, 3, 12, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 364 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(yearly, datetime(2010, 3, 7, 10, 30))
-    def test_yearly_execution_is_not_due(self):
-        due, remaining = yearly.run_every.is_due(
-            datetime(2009, 3, 12, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60)

+ 29 - 26
celery/tests/tasks/test_trace.py

@@ -16,60 +16,63 @@ from celery.app.trace import (
 from celery.tests.case import AppCase
 
 
-def trace(task, args=(), kwargs={}, propagate=False, **opts):
+def trace(app, task, args=(), kwargs={}, propagate=False, **opts):
     return eager_trace_task(task, 'id-1', args, kwargs,
-                            propagate=propagate, **opts)
+                            propagate=propagate, app=app, **opts)
 
 
 class TraceCase(AppCase):
 
     def setup(self):
-        @self.app.task
+        @self.app.task(shared=False)
         def add(x, y):
             return x + y
         self.add = add
 
-        @self.app.task(ignore_result=True)
+        @self.app.task(shared=False, ignore_result=True)
         def add_cast(x, y):
             return x + y
         self.add_cast = add_cast
 
-        @self.app.task
+        @self.app.task(shared=False)
         def raises(exc):
             raise exc
         self.raises = raises
 
+    def trace(self, *args, **kwargs):
+        return trace(self.app, *args, **kwargs)
+
 
 class test_trace(TraceCase):
 
     def test_trace_successful(self):
-        retval, info = trace(self.add, (2, 2), {})
+        retval, info = self.trace(self.add, (2, 2), {})
         self.assertIsNone(info)
         self.assertEqual(retval, 4)
 
     def test_trace_on_success(self):
 
-        @self.app.task(on_success=Mock())
+        @self.app.task(shared=False, on_success=Mock())
         def add_with_success(x, y):
             return x + y
 
-        trace(add_with_success, (2, 2), {})
+        self.trace(add_with_success, (2, 2), {})
         self.assertTrue(add_with_success.on_success.called)
 
     def test_trace_after_return(self):
 
-        @self.app.task(after_return=Mock())
+        @self.app.task(shared=False, after_return=Mock())
         def add_with_after_return(x, y):
             return x + y
 
-        trace(add_with_after_return, (2, 2), {})
+        self.trace(add_with_after_return, (2, 2), {})
         self.assertTrue(add_with_after_return.after_return.called)
 
     def test_with_prerun_receivers(self):
         on_prerun = Mock()
         signals.task_prerun.connect(on_prerun)
         try:
-            trace(self.add, (2, 2), {})
+            self.trace(self.add, (2, 2), {})
             self.assertTrue(on_prerun.called)
         finally:
             signals.task_prerun.receivers[:] = []
@@ -78,7 +81,7 @@ class test_trace(TraceCase):
         on_postrun = Mock()
         signals.task_postrun.connect(on_postrun)
         try:
-            trace(self.add, (2, 2), {})
+            self.trace(self.add, (2, 2), {})
             self.assertTrue(on_postrun.called)
         finally:
             signals.task_postrun.receivers[:] = []
@@ -87,62 +90,62 @@ class test_trace(TraceCase):
         on_success = Mock()
         signals.task_success.connect(on_success)
         try:
-            trace(self.add, (2, 2), {})
+            self.trace(self.add, (2, 2), {})
             self.assertTrue(on_success.called)
         finally:
             signals.task_success.receivers[:] = []
 
     def test_when_chord_part(self):
 
-        @self.app.task
+        @self.app.task(shared=False)
         def add(x, y):
             return x + y
         add.backend = Mock()
 
-        trace(add, (2, 2), {}, request={'chord': uuid()})
+        self.trace(add, (2, 2), {}, request={'chord': uuid()})
         add.backend.on_chord_part_return.assert_called_with(add)
 
     def test_when_backend_cleanup_raises(self):
 
-        @self.app.task
+        @self.app.task(shared=False)
         def add(x, y):
             return x + y
         add.backend = Mock(name='backend')
         add.backend.process_cleanup.side_effect = KeyError()
-        trace(add, (2, 2), {}, eager=False)
+        self.trace(add, (2, 2), {}, eager=False)
         add.backend.process_cleanup.assert_called_with()
         add.backend.process_cleanup.side_effect = MemoryError()
         with self.assertRaises(MemoryError):
-            trace(add, (2, 2), {}, eager=False)
+            self.trace(add, (2, 2), {}, eager=False)
 
     def test_when_Ignore(self):
 
-        @self.app.task
+        @self.app.task(shared=False)
         def ignored():
             raise Ignore()
 
-        retval, info = trace(ignored, (), {})
+        retval, info = self.trace(ignored, (), {})
         self.assertEqual(info.state, states.IGNORED)
 
     def test_trace_SystemExit(self):
         with self.assertRaises(SystemExit):
-            trace(self.raises, (SystemExit(), ), {})
+            self.trace(self.raises, (SystemExit(), ), {})
 
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError('foo', 'bar')
-        _, info = trace(self.raises, (exc, ), {})
+        _, info = self.trace(self.raises, (exc, ), {})
         self.assertEqual(info.state, states.RETRY)
         self.assertIs(info.retval, exc)
 
     def test_trace_exception(self):
         exc = KeyError('foo')
-        _, info = trace(self.raises, (exc, ), {})
+        _, info = self.trace(self.raises, (exc, ), {})
         self.assertEqual(info.state, states.FAILURE)
         self.assertIs(info.retval, exc)
 
     def test_trace_exception_propagate(self):
         with self.assertRaises(KeyError):
-            trace(self.raises, (KeyError('foo'), ), {}, propagate=True)
+            self.trace(self.raises, (KeyError('foo'), ), {}, propagate=True)
 
     @patch('celery.app.trace.build_tracer')
     @patch('celery.app.trace.report_internal_error')
@@ -151,7 +154,7 @@ class test_trace(TraceCase):
         tracer.side_effect = KeyError('foo')
         build_tracer.return_value = tracer
 
-        @self.app.task
+        @self.app.task(shared=False)
         def xtask():
             pass
 
@@ -180,7 +183,7 @@ class test_stackprotection(AppCase):
     def test_stackprotection(self):
         setup_worker_optimizations(self.app)
         try:
-            @self.app.task(bind=True)
+            @self.app.task(shared=False, bind=True)
             def foo(self, i):
                 if i:
                     return foo(0)

+ 7 - 8
celery/tests/utils/test_dispatcher.py

@@ -55,7 +55,7 @@ class DispatcherTests(Case):
         # force cleanup just in case
         signal.receivers = []
 
-    def testExact(self):
+    def test_exact(self):
         a_signal.connect(receiver_1_arg, sender=self)
         expected = [(receiver_1_arg, 'test')]
         result = a_signal.send(sender=self, val='test')
@@ -63,7 +63,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(receiver_1_arg, sender=self)
         self._testIsClean(a_signal)
 
-    def testIgnoredSender(self):
+    def test_ignored_sender(self):
         a_signal.connect(receiver_1_arg)
         expected = [(receiver_1_arg, 'test')]
         result = a_signal.send(sender=self, val='test')
@@ -71,7 +71,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(receiver_1_arg)
         self._testIsClean(a_signal)
 
-    def testGarbageCollected(self):
+    def test_garbage_collected(self):
         a = Callable()
         a_signal.connect(a.a, sender=self)
         expected = []
@@ -81,7 +81,7 @@ class DispatcherTests(Case):
         self.assertEqual(result, expected)
         self._testIsClean(a_signal)
 
-    def testMultipleRegistration(self):
+    def test_multiple_registration(self):
         a = Callable()
         a_signal.connect(a)
         a_signal.connect(a)
@@ -97,7 +97,7 @@ class DispatcherTests(Case):
         garbage_collect()
         self._testIsClean(a_signal)
 
-    def testUidRegistration(self):
+    def test_uid_registration(self):
 
         def uid_based_receiver_1(**kwargs):
             pass
@@ -111,8 +111,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(dispatch_uid='uid')
         self._testIsClean(a_signal)
 
-    def testRobust(self):
-        """Test the sendRobust function"""
+    def test_robust(self):
 
         def fails(val, **kwargs):
             raise ValueError('this')
@@ -125,7 +124,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(fails)
         self._testIsClean(a_signal)
 
-    def testDisconnection(self):
+    def test_disconnection(self):
         receiver_1 = Callable()
         receiver_2 = Callable()
         receiver_3 = Callable()

+ 22 - 8
celery/tests/utils/test_saferef.py

@@ -46,18 +46,30 @@ class SaferefTests(Case):
         del self.ts
         del self.ss
 
-    def testIn(self):
-        """Test the "in" operator for safe references (cmp)"""
+    def test_in(self):
+        """test_in
+
+        Test the "in" operator for safe references (cmp)
+
+        """
         for t in self.ts[:50]:
             self.assertTrue(safe_ref(t.x) in self.ss)
 
-    def testValid(self):
-        """Test that the references are valid (return instance methods)"""
+    def test_valid(self):
+        """test_value
+
+        Test that the references are valid (return instance methods)
+
+        """
         for s in self.ss:
             self.assertTrue(s())
 
-    def testShortCircuit(self):
-        """Test that creation short-circuits to reuse existing references"""
+    def test_shortcircuit(self):
+        """test_shortcircuit
+
+        Test that creation short-circuits to reuse existing references
+
+        """
         sd = {}
         for s in self.ss:
             sd[s] = 1
@@ -67,8 +79,10 @@ class SaferefTests(Case):
             else:
                 self.assertIn(safe_ref(t), sd)
 
-    def testRepresentation(self):
-        """Test that the reference object's representation works
+    def test_representation(self):
+        """test_representation
+
+        Test that the reference object's representation works
 
         XXX Doesn't currently check the results, just that no error
             is raised

+ 4 - 6
celery/tests/utils/test_text.py

@@ -1,6 +1,5 @@
 from __future__ import absolute_import
 
-from celery import Celery
 from celery.utils.text import (
     indent,
     ensure_2lines,
@@ -9,7 +8,7 @@ from celery.utils.text import (
     abbrtask,
     pretty,
 )
-from celery.tests.case import Case
+from celery.tests.case import AppCase, Case
 
 RANDTEXT = """\
 The quick brown
@@ -43,15 +42,14 @@ QUEUE_FORMAT1 = '.> queue1           exchange=exchange1(type1) key=bind1'
 QUEUE_FORMAT2 = '.> queue2           exchange=exchange2(type2) key=bind2'
 
 
-class test_Info(Case):
+class test_Info(AppCase):
 
     def test_textindent(self):
         self.assertEqual(indent(RANDTEXT, 4), RANDTEXT_RES)
 
     def test_format_queues(self):
-        celery = Celery(set_as_current=False)
-        celery.amqp.queues = celery.amqp.Queues(QUEUES)
-        self.assertEqual(sorted(celery.amqp.queues.format().split('\n')),
+        self.app.amqp.queues = self.app.amqp.Queues(QUEUES)
+        self.assertEqual(sorted(self.app.amqp.queues.format().split('\n')),
                          sorted([QUEUE_FORMAT1, QUEUE_FORMAT2]))
 
     def test_ensure_2lines(self):

+ 6 - 6
celery/tests/worker/test_bootsteps.py

@@ -4,10 +4,10 @@ from mock import Mock, patch
 
 from celery import bootsteps
 
-from celery.tests.case import AppCase, Case
+from celery.tests.case import AppCase
 
 
-class test_StepFormatter(Case):
+class test_StepFormatter(AppCase):
 
     def test_get_prefix(self):
         f = bootsteps.StepFormatter()
@@ -53,12 +53,12 @@ class test_StepFormatter(Case):
         })
 
 
-class test_Step(Case):
+class test_Step(AppCase):
 
     class Def(bootsteps.StartStopStep):
         name = 'test_Step.Def'
 
-    def setUp(self):
+    def setup(self):
         self.steps = []
 
     def test_blueprint_name(self, bp='test_blueprint_name'):
@@ -151,12 +151,12 @@ class test_ConsumerStep(AppCase):
         step.start(self)
 
 
-class test_StartStopStep(Case):
+class test_StartStopStep(AppCase):
 
     class Def(bootsteps.StartStopStep):
         name = 'test_StartStopStep.Def'
 
-    def setUp(self):
+    def setup(self):
         self.steps = []
 
     def test_start__stop(self):

+ 14 - 27
celery/tests/worker/test_consumer.py

@@ -59,26 +59,16 @@ class test_Consumer(AppCase):
     def test_sets_heartbeat(self):
         c = self.get_consumer(amqheartbeat=10)
         self.assertEqual(c.amqheartbeat, 10)
-        prev, self.app.conf.BROKER_HEARTBEAT = (
-            self.app.conf.BROKER_HEARTBEAT, 20,
-        )
-        try:
-            c = self.get_consumer(amqheartbeat=None)
-            self.assertEqual(c.amqheartbeat, 20)
-        finally:
-            self.app.conf.BROKER_HEARTBEAT = prev
+        self.app.conf.BROKER_HEARTBEAT = 20
+        c = self.get_consumer(amqheartbeat=None)
+        self.assertEqual(c.amqheartbeat, 20)
 
     def test_gevent_bug_disables_connection_timeout(self):
         with patch('celery.worker.consumer._detect_environment') as de:
             de.return_value = 'gevent'
-            prev, self.app.conf.BROKER_CONNECTION_TIMEOUT = (
-                self.app.conf.BROKER_CONNECTION_TIMEOUT, 33.33,
-            )
-            try:
-                self.get_consumer()
-                self.assertIsNone(self.app.conf.BROKER_CONNECTION_TIMEOUT)
-            finally:
-                self.app.conf.BROKER_CONNECTION_TIMEOUT = prev
+            self.app.conf.BROKER_CONNECTION_TIMEOUT = 33.33
+            self.get_consumer()
+            self.assertIsNone(self.app.conf.BROKER_CONNECTION_TIMEOUT)
 
     def test_limit_task(self):
         c = self.get_consumer()
@@ -163,17 +153,14 @@ class test_Consumer(AppCase):
             c.on_close()
 
     def test_connect_error_handler(self):
-        _prev, self.app.connection = self.app.connection, Mock()
-        try:
-            conn = self.app.connection.return_value = Mock()
-            c = self.get_consumer()
-            self.assertTrue(c.connect())
-            self.assertTrue(conn.ensure_connection.called)
-            errback = conn.ensure_connection.call_args[0][0]
-            conn.alt = [(1, 2, 3)]
-            errback(Mock(), 0)
-        finally:
-            self.app.connection = _prev
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock()
+        c = self.get_consumer()
+        self.assertTrue(c.connect())
+        self.assertTrue(conn.ensure_connection.called)
+        errback = conn.ensure_connection.call_args[0][0]
+        conn.alt = [(1, 2, 3)]
+        errback(Mock(), 0)
 
 
 class test_Heart(AppCase):

+ 59 - 76
celery/tests/worker/test_control.py

@@ -198,28 +198,24 @@ class test_ControlPanel(AppCase):
 
     def test_time_limit(self):
         panel = self.create_panel(consumer=Mock())
-        th, ts = self.mytask.time_limit, self.mytask.soft_time_limit
-        try:
-            r = panel.handle('time_limit', arguments=dict(
-                task_name=self.mytask.name, hard=30, soft=10))
-            self.assertEqual(
-                (self.mytask.time_limit, self.mytask.soft_time_limit),
-                (30, 10),
-            )
-            self.assertIn('ok', r)
-            r = panel.handle('time_limit', arguments=dict(
-                task_name=self.mytask.name, hard=None, soft=None))
-            self.assertEqual(
-                (self.mytask.time_limit, self.mytask.soft_time_limit),
-                (None, None),
-            )
-            self.assertIn('ok', r)
+        r = panel.handle('time_limit', arguments=dict(
+            task_name=self.mytask.name, hard=30, soft=10))
+        self.assertEqual(
+            (self.mytask.time_limit, self.mytask.soft_time_limit),
+            (30, 10),
+        )
+        self.assertIn('ok', r)
+        r = panel.handle('time_limit', arguments=dict(
+            task_name=self.mytask.name, hard=None, soft=None))
+        self.assertEqual(
+            (self.mytask.time_limit, self.mytask.soft_time_limit),
+            (None, None),
+        )
+        self.assertIn('ok', r)
 
-            r = panel.handle('time_limit', arguments=dict(
-                task_name='248e8afya9s8dh921eh928', hard=30))
-            self.assertIn('error', r)
-        finally:
-            self.time_limit, self.soft_time_limit = th, ts
+        r = panel.handle('time_limit', arguments=dict(
+            task_name='248e8afya9s8dh921eh928', hard=30))
+        self.assertIn('error', r)
 
     def test_active_queues(self):
         import kombu
@@ -381,19 +377,15 @@ class test_ControlPanel(AppCase):
         panel = self.create_panel(app=self.app, consumer=consumer)
 
         task = self.app.tasks[self.mytask.name]
-        old_rate_limit = task.rate_limit
-        try:
-            panel.handle('rate_limit', arguments=dict(task_name=task.name,
-                                                      rate_limit='100/m'))
-            self.assertEqual(task.rate_limit, '100/m')
-            self.assertTrue(consumer.reset)
-            consumer.reset = False
-            panel.handle('rate_limit', arguments=dict(task_name=task.name,
-                                                      rate_limit=0))
-            self.assertEqual(task.rate_limit, 0)
-            self.assertTrue(consumer.reset)
-        finally:
-            task.rate_limit = old_rate_limit
+        panel.handle('rate_limit', arguments=dict(task_name=task.name,
+                                                  rate_limit='100/m'))
+        self.assertEqual(task.rate_limit, '100/m')
+        self.assertTrue(consumer.reset)
+        consumer.reset = False
+        panel.handle('rate_limit', arguments=dict(task_name=task.name,
+                                                  rate_limit=0))
+        self.assertEqual(task.rate_limit, 0)
+        self.assertTrue(consumer.reset)
 
     def test_rate_limit_nonexistant_task(self):
         self.panel.handle('rate_limit', arguments={
@@ -509,13 +501,10 @@ class test_ControlPanel(AppCase):
             panel.handle('pool_restart', {'reloader': _reload})
 
         self.app.conf.CELERYD_POOL_RESTARTS = True
-        try:
-            panel.handle('pool_restart', {'reloader': _reload})
-            self.assertTrue(consumer.controller.pool.restart.called)
-            self.assertFalse(_reload.called)
-            self.assertFalse(_import.called)
-        finally:
-            self.app.conf.CELERYD_POOL_RESTARTS = False
+        panel.handle('pool_restart', {'reloader': _reload})
+        self.assertTrue(consumer.controller.pool.restart.called)
+        self.assertFalse(_reload.called)
+        self.assertFalse(_import.called)
 
     def test_pool_restart_import_modules(self):
         consumer = Consumer(self.app)
@@ -527,18 +516,15 @@ class test_ControlPanel(AppCase):
         _reload = Mock()
 
         self.app.conf.CELERYD_POOL_RESTARTS = True
-        try:
-            panel.handle('pool_restart', {'modules': ['foo', 'bar'],
-                                          'reloader': _reload})
-
-            self.assertTrue(consumer.controller.pool.restart.called)
-            self.assertFalse(_reload.called)
-            self.assertItemsEqual(
-                [call('bar'), call('foo')],
-                _import.call_args_list,
-            )
-        finally:
-            self.app.conf.CELERYD_POOL_RESTARTS = False
+        panel.handle('pool_restart', {'modules': ['foo', 'bar'],
+                                      'reloader': _reload})
+
+        self.assertTrue(consumer.controller.pool.restart.called)
+        self.assertFalse(_reload.called)
+        self.assertItemsEqual(
+            [call('bar'), call('foo')],
+            _import.call_args_list,
+        )
 
     def test_pool_restart_reload_modules(self):
         consumer = Consumer(self.app)
@@ -550,26 +536,23 @@ class test_ControlPanel(AppCase):
         _reload = Mock()
 
         self.app.conf.CELERYD_POOL_RESTARTS = True
-        try:
-            with patch.dict(sys.modules, {'foo': None}):
-                panel.handle('pool_restart', {'modules': ['foo'],
-                                              'reload': False,
-                                              'reloader': _reload})
-
-                self.assertTrue(consumer.controller.pool.restart.called)
-                self.assertFalse(_reload.called)
-                self.assertFalse(_import.called)
-
-                _import.reset_mock()
-                _reload.reset_mock()
-                consumer.controller.pool.restart.reset_mock()
-
-                panel.handle('pool_restart', {'modules': ['foo'],
-                                              'reload': True,
-                                              'reloader': _reload})
-
-                self.assertTrue(consumer.controller.pool.restart.called)
-                self.assertTrue(_reload.called)
-                self.assertFalse(_import.called)
-        finally:
-            self.app.conf.CELERYD_POOL_RESTARTS = False
+        with patch.dict(sys.modules, {'foo': None}):
+            panel.handle('pool_restart', {'modules': ['foo'],
+                                          'reload': False,
+                                          'reloader': _reload})
+
+            self.assertTrue(consumer.controller.pool.restart.called)
+            self.assertFalse(_reload.called)
+            self.assertFalse(_import.called)
+
+            _import.reset_mock()
+            _reload.reset_mock()
+            consumer.controller.pool.restart.reset_mock()
+
+            panel.handle('pool_restart', {'modules': ['foo'],
+                                          'reload': True,
+                                          'reloader': _reload})
+
+            self.assertTrue(consumer.controller.pool.restart.called)
+            self.assertTrue(_reload.called)
+            self.assertFalse(_import.called)

+ 2 - 2
celery/tests/worker/test_heartbeat.py

@@ -1,7 +1,7 @@
 from __future__ import absolute_import
 
 from celery.worker.heartbeat import Heart
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 class MockDispatcher(object):
@@ -45,7 +45,7 @@ class MockTimer(object):
         entry.cancel()
 
 
-class test_Heart(Case):
+class test_Heart(AppCase):
 
     def test_start_stop(self):
         timer = MockTimer()

+ 1 - 1
celery/tests/worker/test_loops.py

@@ -100,7 +100,7 @@ class test_asynloop(AppCase):
 
     def setup(self):
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def add(x, y):
             return x + y
         self.add = add

+ 27 - 30
celery/tests/worker/test_request.py

@@ -84,7 +84,7 @@ def jail(app, task_id, name, args, kwargs):
     task = app.tasks[name]
     task.__trace__ = None  # rebuild
     return trace_task(
-        task, task_id, args, kwargs, request=request, eager=False,
+        task, task_id, args, kwargs, request=request, eager=False, app=app,
     )
 
 
@@ -161,25 +161,23 @@ class test_trace_task(AppCase):
         self.assertEqual(ret, 4)
 
     def test_marked_as_started(self):
+        _started = []
 
-        class Backend(self.mytask.backend.__class__):
-            _started = []
-
-            def store_result(self, tid, meta, state):
-                if state == states.STARTED:
-                    self._started.append(tid)
-
-        self.mytask.backend = Backend(self.app)
+        def store_result(tid, meta, state):
+            if state == states.STARTED:
+                _started.append(tid)
+        self.mytask.backend.store_result = Mock(name='store_result')
+        self.mytask.backend.store_result.side_effect = store_result
         self.mytask.track_started = True
 
         tid = uuid()
         jail(self.app, tid, self.mytask.name, [2], {})
-        self.assertIn(tid, Backend._started)
+        self.assertIn(tid, _started)
 
         self.mytask.ignore_result = True
         tid = uuid()
         jail(self.app, tid, self.mytask.name, [2], {})
-        self.assertNotIn(tid, Backend._started)
+        self.assertNotIn(tid, _started)
 
     def test_execute_jail_failure(self):
         ret = jail(
@@ -309,18 +307,15 @@ class test_Request(AppCase):
         task.freeze()
         req = self.get_request(task)
         self.add.accept_magic_kwargs = True
-        try:
-            pool = Mock()
-            req.execute_using_pool(pool)
-            self.assertTrue(pool.apply_async.called)
-            args = pool.apply_async.call_args[1]['args']
-            self.assertEqual(args[0], task.task)
-            self.assertEqual(args[1], task.id)
-            self.assertEqual(args[2], task.args)
-            kwargs = args[3]
-            self.assertEqual(kwargs.get('task_name'), task.task)
-        finally:
-            self.add.accept_magic_kwargs = False
+        pool = Mock()
+        req.execute_using_pool(pool)
+        self.assertTrue(pool.apply_async.called)
+        args = pool.apply_async.call_args[1]['args']
+        self.assertEqual(args[0], task.task)
+        self.assertEqual(args[1], task.id)
+        self.assertEqual(args[2], task.args)
+        kwargs = args[3]
+        self.assertEqual(kwargs.get('task_name'), task.task)
 
     def test_task_wrapper_repr(self):
         job = TaskRequest(
@@ -697,6 +692,7 @@ class test_Request(AppCase):
         try:
             self.mytask.__trace__ = build_tracer(
                 self.mytask.name, self.mytask, self.app.loader, 'test',
+                app=self.app,
             )
             res = trace.trace_task_ret(self.mytask.name, uuid(), [4], {})
             self.assertEqual(res, 4 ** 4)
@@ -704,24 +700,25 @@ class test_Request(AppCase):
             reset_worker_optimizations()
             self.assertIs(trace.trace_task_ret, trace._trace_task_ret)
         delattr(self.mytask, '__trace__')
-        res = trace.trace_task_ret(self.mytask.name, uuid(), [4], {})
+        res = trace.trace_task_ret(
+            self.mytask.name, uuid(), [4], {}, app=self.app,
+        )
         self.assertEqual(res, 4 ** 4)
 
     def test_trace_task_ret(self):
-        self.app.set_current()   # XXX compat test
         self.mytask.__trace__ = build_tracer(
             self.mytask.name, self.mytask, self.app.loader, 'test',
+            app=self.app,
         )
-        res = _trace_task_ret(self.mytask.name, uuid(), [4], {})
+        res = _trace_task_ret(self.mytask.name, uuid(), [4], {}, app=self.app)
         self.assertEqual(res, 4 ** 4)
 
     def test_trace_task_ret__no_trace(self):
-        self.app.set_current()  # XXX compat test
         try:
             delattr(self.mytask, '__trace__')
         except AttributeError:
             pass
-        res = _trace_task_ret(self.mytask.name, uuid(), [4], {})
+        res = _trace_task_ret(self.mytask.name, uuid(), [4], {}, app=self.app)
         self.assertEqual(res, 4 ** 4)
 
     def test_trace_catches_exception(self):
@@ -735,7 +732,7 @@ class test_Request(AppCase):
 
         with self.assertWarnsRegex(RuntimeWarning,
                                    r'Exception raised outside'):
-            res = trace_task(raising, uuid(), [], {})
+            res = trace_task(raising, uuid(), [], {}, app=self.app)
             self.assertIsInstance(res, ExceptionInfo)
 
     def test_worker_task_trace_handle_retry(self):
@@ -865,7 +862,7 @@ class test_Request(AppCase):
     def test_execute_success_some_kwargs(self):
         scratch = {'task_id': None}
 
-        @self.app.task(accept_magic_kwargs=True)
+        @self.app.task(shared=False, accept_magic_kwargs=True)
         def mytask_some_kwargs(i, task_id):
             scratch['task_id'] = task_id
             return i ** i

+ 2 - 2
celery/tests/worker/test_revoke.py

@@ -1,10 +1,10 @@
 from __future__ import absolute_import
 
 from celery.worker import state
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class test_revoked(Case):
+class test_revoked(AppCase):
 
     def test_is_working(self):
         state.revoked.add('foo')

+ 8 - 15
celery/tests/worker/test_state.py

@@ -9,30 +9,22 @@ from celery.datastructures import LimitedSet
 from celery.exceptions import SystemTerminate
 from celery.worker import state
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class StateResetCase(Case):
+class StateResetCase(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.reset_state()
-        self.on_setup()
 
-    def tearDown(self):
+    def teardown(self):
         self.reset_state()
-        self.on_teardown()
 
     def reset_state(self):
         state.active_requests.clear()
         state.revoked.clear()
         state.total_count.clear()
 
-    def on_setup(self):
-        pass
-
-    def on_teardown(self):
-        pass
-
 
 class MockShelve(dict):
     filename = None
@@ -54,9 +46,9 @@ class MyPersistent(state.Persistent):
     storage = MockShelve()
 
 
-class test_maybe_shutdown(Case):
+class test_maybe_shutdown(AppCase):
 
-    def tearDown(self):
+    def teardown(self):
         state.should_stop = False
         state.should_terminate = False
 
@@ -73,7 +65,8 @@ class test_maybe_shutdown(Case):
 
 class test_Persistent(StateResetCase):
 
-    def on_setup(self):
+    def setup(self):
+        self.reset_state()
         self.p = MyPersistent(state, filename='celery-state')
 
     def test_close_twice(self):

+ 9 - 12
celery/tests/worker/test_strategy.py

@@ -6,7 +6,6 @@ from mock import Mock, patch
 
 from kombu.utils.limits import TokenBucket
 
-from celery import Celery
 from celery.worker import state
 from celery.utils.timeutils import rate
 
@@ -15,6 +14,13 @@ from celery.tests.case import AppCase, body_from_sig
 
 class test_default_strategy(AppCase):
 
+    def setup(self):
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+
+        self.add = add
+
     class Context(object):
 
         def __init__(self, sig, s, reserved, consumer, message, body):
@@ -52,15 +58,6 @@ class test_default_strategy(AppCase):
                 return self.consumer.timer.apply_at.call_args[0][0]
             raise ValueError('request not handled')
 
-    def setup(self):
-        self.c = Celery(set_as_current=False)
-
-        @self.c.task()
-        def add(x, y):
-            return x + y
-
-        self.add = add
-
     @contextmanager
     def _context(self, sig,
                  rate_limits=True, events=True, utc=True, limit=None):
@@ -74,11 +71,11 @@ class test_default_strategy(AppCase):
             consumer.task_buckets[sig.task] = bucket
         consumer.disable_rate_limits = not rate_limits
         consumer.event_dispatcher.enabled = events
-        s = sig.type.start_strategy(self.c, consumer, task_reserved=reserved)
+        s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)
         self.assertTrue(s)
 
         message = Mock()
-        body = body_from_sig(self.c, sig, utc=utc)
+        body = body_from_sig(self.app, sig, utc=utc)
 
         yield self.Context(sig, s, reserved, consumer, message, body)
 

+ 60 - 171
celery/tests/worker/test_worker.py

@@ -9,7 +9,7 @@ from threading import Event
 
 from billiard.exceptions import WorkerLostError
 from kombu import Connection
-from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
+from kombu.common import QoS, ignore_errors
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from mock import call, Mock, patch
@@ -20,8 +20,6 @@ from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.exceptions import SystemTerminate, TaskRevokedError
 from celery.five import Empty, range, Queue as FastQueue
-from celery.task import task as task_dec
-from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import components
 from celery.worker import consumer
@@ -32,7 +30,7 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
-from celery.tests.case import AppCase, Case, restore_logging
+from celery.tests.case import AppCase, restore_logging
 
 
 def MockStep(step=None):
@@ -108,16 +106,6 @@ class MockHeart(object):
         self.closed = True
 
 
-@task_dec()
-def foo_task(x, y, z, **kwargs):
-    return x * y * z
-
-
-@periodic_task_dec(run_every=60)
-def foo_periodic_task():
-    return 'foo'
-
-
 def create_message(channel, **data):
     data.setdefault('id', uuid())
     channel.no_ack_consumers = set()
@@ -129,117 +117,17 @@ def create_message(channel, **data):
     return m
 
 
-class test_QoS(Case):
-
-    class _QoS(QoS):
-        def __init__(self, value):
-            self.value = value
-            QoS.__init__(self, None, value)
-
-        def set(self, value):
-            return value
-
-    def test_qos_increment_decrement(self):
-        qos = self._QoS(10)
-        self.assertEqual(qos.increment_eventually(), 11)
-        self.assertEqual(qos.increment_eventually(3), 14)
-        self.assertEqual(qos.increment_eventually(-30), 14)
-        self.assertEqual(qos.decrement_eventually(7), 7)
-        self.assertEqual(qos.decrement_eventually(), 6)
-
-    def test_qos_disabled_increment_decrement(self):
-        qos = self._QoS(0)
-        self.assertEqual(qos.increment_eventually(), 0)
-        self.assertEqual(qos.increment_eventually(3), 0)
-        self.assertEqual(qos.increment_eventually(-30), 0)
-        self.assertEqual(qos.decrement_eventually(7), 0)
-        self.assertEqual(qos.decrement_eventually(), 0)
-        self.assertEqual(qos.decrement_eventually(10), 0)
-
-    def test_qos_thread_safe(self):
-        qos = self._QoS(10)
-
-        def add():
-            for i in range(1000):
-                qos.increment_eventually()
-
-        def sub():
-            for i in range(1000):
-                qos.decrement_eventually()
-
-        def threaded(funs):
-            from threading import Thread
-            threads = [Thread(target=fun) for fun in funs]
-            for thread in threads:
-                thread.start()
-            for thread in threads:
-                thread.join()
-
-        threaded([add, add])
-        self.assertEqual(qos.value, 2010)
-
-        qos.value = 1000
-        threaded([add, sub])  # n = 2
-        self.assertEqual(qos.value, 1000)
-
-    def test_exceeds_short(self):
-        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
-        qos.update()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-        qos.increment_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.increment_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-
-    def test_consumer_increment_decrement(self):
-        mconsumer = Mock()
-        qos = QoS(mconsumer.qos, 10)
-        qos.update()
-        self.assertEqual(qos.value, 10)
-        mconsumer.qos.assert_called_with(prefetch_count=10)
-        qos.decrement_eventually()
-        qos.update()
-        self.assertEqual(qos.value, 9)
-        mconsumer.qos.assert_called_with(prefetch_count=9)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 8)
-        mconsumer.qos.assert_called_with(prefetch_count=9)
-        self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args)
-
-        # Does not decrement 0 value
-        qos.value = 0
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 0)
-        qos.increment_eventually()
-        self.assertEqual(qos.value, 0)
-
-    def test_consumer_decrement_eventually(self):
-        mconsumer = Mock()
-        qos = QoS(mconsumer.qos, 10)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 9)
-        qos.value = 0
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 0)
-
-    def test_set(self):
-        mconsumer = Mock()
-        qos = QoS(mconsumer.qos, 10)
-        qos.set(12)
-        self.assertEqual(qos.prev, 12)
-        qos.set(qos.prev)
-
-
 class test_Consumer(AppCase):
 
     def setup(self):
         self.buffer = FastQueue()
         self.timer = Timer()
 
+        @self.app.task(shared=False)
+        def foo_task(x, y, z):
+            return x * y * z
+        self.foo_task = foo_task
+
     def teardown(self):
         self.timer.stop()
 
@@ -326,7 +214,7 @@ class test_Consumer(AppCase):
         to_timestamp.side_effect = OverflowError()
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
-        m = create_message(Mock(), task=foo_task.name,
+        m = create_message(Mock(), task=self.foo_task.name,
                            args=('2, 2'),
                            kwargs={},
                            eta=datetime.now().isoformat())
@@ -344,7 +232,7 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
         l.steps.pop()
-        m = create_message(Mock(), task=foo_task.name,
+        m = create_message(Mock(), task=self.foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.event_dispatcher = Mock()
@@ -383,7 +271,7 @@ class test_Consumer(AppCase):
     def test_receieve_message(self):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
-        m = create_message(Mock(), task=foo_task.name,
+        m = create_message(Mock(), task=self.foo_task.name,
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
         callback = self._get_on_message(l)
@@ -391,7 +279,7 @@ class test_Consumer(AppCase):
 
         in_bucket = self.buffer.get_nowait()
         self.assertIsInstance(in_bucket, Request)
-        self.assertEqual(in_bucket.name, foo_task.name)
+        self.assertEqual(in_bucket.name, self.foo_task.name)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.timer.empty())
 
@@ -520,7 +408,7 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         m = create_message(
-            Mock(), task=foo_task.name,
+            Mock(), task=self.foo_task.name,
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
             args=[2, 4, 8], kwargs={},
         )
@@ -539,7 +427,7 @@ class test_Consumer(AppCase):
         items = [entry[2] for entry in self.timer.queue]
         found = 0
         for item in items:
-            if item.args[0].name == foo_task.name:
+            if item.args[0].name == self.foo_task.name:
                 found = True
         self.assertTrue(found)
         self.assertGreater(l.qos.value, current_pcount)
@@ -570,7 +458,7 @@ class test_Consumer(AppCase):
         l.steps.pop()
         backend = Mock()
         id = uuid()
-        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
+        t = create_message(backend, task=self.foo_task.name, args=[2, 4, 8],
                            kwargs={}, id=id)
         from celery.worker.state import revoked
         revoked.add(id)
@@ -619,7 +507,7 @@ class test_Consumer(AppCase):
         l.event_dispatcher._outbound_buffer = deque()
         backend = Mock()
         m = create_message(
-            backend, task=foo_task.name,
+            backend, task=self.foo_task.name,
             args=[2, 4, 8], kwargs={},
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
         )
@@ -627,10 +515,8 @@ class test_Consumer(AppCase):
         l.blueprint.start(l)
         p = l.app.conf.BROKER_CONNECTION_RETRY
         l.app.conf.BROKER_CONNECTION_RETRY = False
-        try:
-            l.blueprint.start(l)
-        finally:
-            l.app.conf.BROKER_CONNECTION_RETRY = p
+        l.blueprint.start(l)
+        l.app.conf.BROKER_CONNECTION_RETRY = p
         l.blueprint.restart(l)
         l.event_dispatcher = Mock()
         callback = self._get_on_message(l)
@@ -641,7 +527,7 @@ class test_Consumer(AppCase):
         eta, priority, entry = in_hold
         task = entry.args[0]
         self.assertIsInstance(task, Request)
-        self.assertEqual(task.name, foo_task.name)
+        self.assertEqual(task.name, self.foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         with self.assertRaises(Empty):
             self.buffer.get_nowait()
@@ -823,6 +709,11 @@ class test_WorkController(AppCase):
         self.logger = worker.logger = Mock()
         self.comp_logger = components.logger = Mock()
 
+        @self.app.task(shared=False)
+        def foo_task(x, y, z):
+            return x * y * z
+        self.foo_task = foo_task
+
     def teardown(self):
         from celery import worker
         worker.logger = self._logger
@@ -838,15 +729,11 @@ class test_WorkController(AppCase):
 
     def test_setup_queues_worker_direct(self):
         self.app.conf.CELERY_WORKER_DIRECT = True
-        _qs, self.app.amqp.__dict__['queues'] = self.app.amqp.queues, Mock()
-        try:
-            self.worker.setup_queues({})
-            self.app.amqp.queues.select_add.assert_called_with(
-                worker_direct(self.worker.hostname),
-            )
-        finally:
-            self.app.amqp.queues = _qs
-            self.app.conf.CELERY_WORKER_DIRECT = False
+        self.app.amqp.__dict__['queues'] = Mock()
+        self.worker.setup_queues({})
+        self.app.amqp.queues.select_add.assert_called_with(
+            worker_direct(self.worker.hostname),
+        )
 
     def test_send_worker_shutdown(self):
         with patch('celery.signals.worker_shutdown') as ws:
@@ -881,40 +768,42 @@ class test_WorkController(AppCase):
     @patch('celery.platforms.set_mp_process_title')
     def test_process_initializer(self, set_mp_process_title, _signals):
         with restore_logging():
-            from celery import Celery
             from celery import signals
             from celery._state import _tls
-            from celery.concurrency.processes import process_initializer
-            from celery.concurrency.processes import (WORKER_SIGRESET,
-                                                      WORKER_SIGIGNORE)
+            from celery.concurrency.processes import (
+                process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
+            )
 
             def on_worker_process_init(**kwargs):
                 on_worker_process_init.called = True
             on_worker_process_init.called = False
             signals.worker_process_init.connect(on_worker_process_init)
 
-            loader = Mock()
-            loader.override_backends = {}
-            app = Celery(loader=loader, set_as_current=False)
-            app.loader = loader
-            app.conf = AttributeDict(DEFAULTS)
-            process_initializer(app, 'awesome.worker.com')
-            _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
-            _signals.reset.assert_any_call(*WORKER_SIGRESET)
-            self.assertTrue(app.loader.init_worker.call_count)
-            self.assertTrue(on_worker_process_init.called)
-            self.assertIs(_tls.current_app, app)
-            set_mp_process_title.assert_called_with(
-                'celeryd', hostname='awesome.worker.com',
-            )
-
-            with patch('celery.app.trace.setup_worker_optimizations') as swo:
-                os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
-                try:
-                    process_initializer(app, 'luke.worker.com')
-                    swo.assert_called_with(app)
-                finally:
-                    os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
+            def Loader(*args, **kwargs):
+                loader = Mock(*args, **kwargs)
+                loader.conf = {}
+                loader.override_backends = {}
+                return loader
+
+            with self.Celery(loader=Loader) as app:
+                app.conf = AttributeDict(DEFAULTS)
+                process_initializer(app, 'awesome.worker.com')
+                _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
+                _signals.reset.assert_any_call(*WORKER_SIGRESET)
+                self.assertTrue(app.loader.init_worker.call_count)
+                self.assertTrue(on_worker_process_init.called)
+                self.assertIs(_tls.current_app, app)
+                set_mp_process_title.assert_called_with(
+                    'celeryd', hostname='awesome.worker.com',
+                )
+
+                with patch('celery.app.trace.setup_worker_optimizations') as S:
+                    os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
+                    try:
+                        process_initializer(app, 'luke.worker.com')
+                        S.assert_called_with(app)
+                    finally:
+                        os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
 
     def test_attrs(self):
         worker = self.worker
@@ -976,7 +865,7 @@ class test_WorkController(AppCase):
         worker = self.worker
         worker.pool = Mock()
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)
@@ -988,7 +877,7 @@ class test_WorkController(AppCase):
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
@@ -1002,7 +891,7 @@ class test_WorkController(AppCase):
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = SystemTerminate()
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
@@ -1016,7 +905,7 @@ class test_WorkController(AppCase):
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = KeyError('some exception')
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)