瀏覽代碼

Huge test refactoring so that the unit tests no longer uses the current app.

- During test runs the current_app and default_app is set to an object
  that will crash if any attribute is accessed.

    For tests that require the current_app to exist (like pickle tests and the
    tests in compat_modules) there is a now decorator that can be used to allow
    it:

        from celery.tests.case import AppCase, depends_on_current_app

        class test_Example(AppCase):

            @depends_on_current_app
            def test_pickle_result(self):
                pickle.loads(pickle.dumps(self.app.AsyncResult))

- There is no longer any default configuration set when the test suite runs.

    The test suite used to setup the Celery global environment by
    setting the CELERY_CONFIG_MODULE environment variable to the module
    ``celery.tests.config``.  This will no longer happen, so tests can no
    longer use the global default_app.   See below.

- AppCase.app is now created and destroyed for every test function.

- Unit tests should now use a custom app constructor: AppCase.Celery

    This will return an app object that is configured to be used for
    testing.  E.g. the broker is set to memory:// and the backend
    is set to cache+memory://, and any other configuration that
    makes it suitable for testing::

        from celery.tests.case import AppCase

        class test_Example(AppCase):

            def test_foo(self):
                with self.Celery() as app:
                    ...

    It's rarely necessary to create a new app though, as the default app
    set up in AppCase (`self.app`) is usually sufficient.  Before you would
    have to create new apps if you wanted to make changes to the app
    without causing side effects for later tests, but now the `self.app`
    is destroyed and recreated for every test.

    Note that this app will have the ``set_as_current`` argument set to False.
Ask Solem 11 年之前
父節點
當前提交
ab83c9ef75
共有 74 個文件被更改,包括 1326 次插入2263 次删除
  1. 15 14
      celery/tests/__init__.py
  2. 1 1
      celery/tests/app/test_amqp.py
  3. 2 2
      celery/tests/app/test_annotations.py
  4. 100 147
      celery/tests/app/test_app.py
  5. 34 37
      celery/tests/app/test_beat.py
  6. 3 3
      celery/tests/app/test_builtins.py
  7. 2 2
      celery/tests/app/test_celery.py
  8. 3 9
      celery/tests/app/test_control.py
  9. 4 4
      celery/tests/app/test_defaults.py
  10. 2 2
      celery/tests/app/test_exceptions.py
  11. 27 36
      celery/tests/app/test_loaders.py
  12. 8 11
      celery/tests/app/test_log.py
  13. 34 48
      celery/tests/app/test_registry.py
  14. 62 70
      celery/tests/app/test_routes.py
  15. 2 2
      celery/tests/app/test_schedules.py
  16. 2 2
      celery/tests/app/test_utils.py
  17. 7 8
      celery/tests/backends/test_amqp.py
  18. 13 11
      celery/tests/backends/test_backends.py
  19. 10 14
      celery/tests/backends/test_base.py
  20. 7 13
      celery/tests/backends/test_cache.py
  21. 22 25
      celery/tests/backends/test_cassandra.py
  22. 46 58
      celery/tests/backends/test_couchbase.py
  23. 20 22
      celery/tests/backends/test_database.py
  24. 6 8
      celery/tests/backends/test_mongodb.py
  25. 14 21
      celery/tests/backends/test_redis.py
  26. 0 2
      celery/tests/bin/test_amqp.py
  27. 16 16
      celery/tests/bin/test_base.py
  28. 5 7
      celery/tests/bin/test_beat.py
  29. 30 25
      celery/tests/bin/test_celery.py
  30. 10 9
      celery/tests/bin/test_celeryd_detach.py
  31. 4 4
      celery/tests/bin/test_celeryevdump.py
  32. 6 6
      celery/tests/bin/test_multi.py
  33. 53 74
      celery/tests/bin/test_worker.py
  34. 128 47
      celery/tests/case.py
  35. 63 10
      celery/tests/compat_modules/test_compat.py
  36. 6 7
      celery/tests/compat_modules/test_compat_utils.py
  37. 4 3
      celery/tests/compat_modules/test_decorators.py
  38. 13 10
      celery/tests/compat_modules/test_http.py
  39. 3 2
      celery/tests/compat_modules/test_messaging.py
  40. 88 42
      celery/tests/compat_modules/test_sets.py
  41. 2 2
      celery/tests/concurrency/test_concurrency.py
  42. 4 4
      celery/tests/concurrency/test_eventlet.py
  43. 7 7
      celery/tests/concurrency/test_gevent.py
  44. 3 3
      celery/tests/concurrency/test_pool.py
  45. 2 2
      celery/tests/concurrency/test_solo.py
  46. 2 2
      celery/tests/concurrency/test_threads.py
  47. 0 44
      celery/tests/config.py
  48. 24 26
      celery/tests/contrib/test_abortable.py
  49. 1 1
      celery/tests/contrib/test_methods.py
  50. 5 5
      celery/tests/contrib/test_migrate.py
  51. 10 13
      celery/tests/events/test_events.py
  52. 4 4
      celery/tests/events/test_state.py
  53. 5 7
      celery/tests/fixups/test_django.py
  54. 13 9
      celery/tests/functional/case.py
  55. 25 37
      celery/tests/security/test_security.py
  56. 13 20
      celery/tests/tasks/test_canvas.py
  57. 13 19
      celery/tests/tasks/test_chord.py
  58. 3 3
      celery/tests/tasks/test_context.py
  59. 16 16
      celery/tests/tasks/test_result.py
  60. 49 815
      celery/tests/tasks/test_tasks.py
  61. 29 26
      celery/tests/tasks/test_trace.py
  62. 7 8
      celery/tests/utils/test_dispatcher.py
  63. 22 8
      celery/tests/utils/test_saferef.py
  64. 4 6
      celery/tests/utils/test_text.py
  65. 6 6
      celery/tests/worker/test_bootsteps.py
  66. 14 27
      celery/tests/worker/test_consumer.py
  67. 59 76
      celery/tests/worker/test_control.py
  68. 2 2
      celery/tests/worker/test_heartbeat.py
  69. 1 1
      celery/tests/worker/test_loops.py
  70. 27 30
      celery/tests/worker/test_request.py
  71. 2 2
      celery/tests/worker/test_revoke.py
  72. 8 15
      celery/tests/worker/test_state.py
  73. 9 12
      celery/tests/worker/test_strategy.py
  74. 60 171
      celery/tests/worker/test_worker.py

+ 15 - 14
celery/tests/__init__.py

@@ -14,20 +14,16 @@ except NameError:
     class WindowsError(Exception):
     class WindowsError(Exception):
         pass
         pass
 
 
-config_module = os.environ.setdefault(
-    'CELERY_TEST_CONFIG_MODULE', 'celery.tests.config',
+os.environ.update(
+    #: warn if config module not found
+    C_WNOCONF='yes',
+    EVENTLET_NOPATCH='yes',
+    GEVENT_NOPATCH='yes',
+    KOMBU_DISABLE_LIMIT_PROTECTION='yes',
+    # virtual.QoS will not do sanity assertions when this is set.
+    KOMBU_UNITTEST='yes',
 )
 )
 
 
-os.environ.setdefault('CELERY_CONFIG_MODULE', config_module)
-os.environ['CELERY_LOADER'] = 'default'
-os.environ['EVENTLET_NOPATCH'] = 'yes'
-os.environ['GEVENT_NOPATCH'] = 'yes'
-os.environ['KOMBU_DISABLE_LIMIT_PROTECTION'] = 'yes'
-os.environ['CELERY_BROKER_URL'] = 'memory://'
-
-# virtual.QoS will not do sanity assertions when this is set.
-os.environ['KOMBU_UNITTEST'] = 'yes'
-
 
 
 def setup():
 def setup():
     if os.environ.get('COVER_ALL_MODULES') or '--with-coverage3' in sys.argv:
     if os.environ.get('COVER_ALL_MODULES') or '--with-coverage3' in sys.argv:
@@ -35,6 +31,9 @@ def setup():
         with catch_warnings(record=True):
         with catch_warnings(record=True):
             import_all_modules()
             import_all_modules()
         warnings.resetwarnings()
         warnings.resetwarnings()
+    from celery.tests.case import Trap
+    from celery._state import set_default_app
+    set_default_app(Trap())
 
 
 
 
 def teardown():
 def teardown():
@@ -81,9 +80,11 @@ def find_distribution_modules(name=__name__, file=__file__):
 
 
 
 
 def import_all_modules(name=__name__, file=__file__,
 def import_all_modules(name=__name__, file=__file__,
-                       skip=['celery.decorators', 'celery.contrib.batches']):
+                       skip=('celery.decorators',
+                             'celery.contrib.batches',
+                             'celery.task')):
     for module in find_distribution_modules(name, file):
     for module in find_distribution_modules(name, file):
-        if module not in skip:
+        if not module.startswith(skip):
             try:
             try:
                 import_module(module)
                 import_module(module)
             except ImportError:
             except ImportError:

+ 1 - 1
celery/tests/app/test_amqp.py

@@ -81,7 +81,7 @@ class test_compat_TaskPublisher(AppCase):
         self.assertEqual(producer.exchange.type, 'topic')
         self.assertEqual(producer.exchange.type, 'topic')
 
 
     def test_compat_exchange_is_Exchange(self):
     def test_compat_exchange_is_Exchange(self):
-        producer = TaskPublisher(exchange=Exchange('foo'))
+        producer = TaskPublisher(exchange=Exchange('foo'), app=self.app)
         self.assertEqual(producer.exchange.name, 'foo')
         self.assertEqual(producer.exchange.name, 'foo')
 
 
 
 

+ 2 - 2
celery/tests/app/test_annotations.py

@@ -13,12 +13,12 @@ class MyAnnotation(object):
 class AnnotationCase(AppCase):
 class AnnotationCase(AppCase):
 
 
     def setup(self):
     def setup(self):
-        @self.app.task()
+        @self.app.task(shared=False)
         def add(x, y):
         def add(x, y):
             return x + y
             return x + y
         self.add = add
         self.add = add
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mul(x, y):
         def mul(x, y):
             return x * y
             return x * y
         self.mul = mul
         self.mul = mul

+ 100 - 147
celery/tests/app/test_app.py

@@ -7,7 +7,7 @@ from pickle import loads, dumps
 
 
 from kombu import Exchange
 from kombu import Exchange
 
 
-from celery import Celery, shared_task, current_app
+from celery import shared_task, current_app
 from celery import app as _app
 from celery import app as _app
 from celery import _state
 from celery import _state
 from celery.app import base as _appbase
 from celery.app import base as _appbase
@@ -20,11 +20,13 @@ from celery.utils.serialization import pickle
 
 
 from celery.tests import config
 from celery.tests import config
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Case,
+    AppCase,
+    depends_on_current_app,
     mask_modules,
     mask_modules,
     platform_pyimp,
     platform_pyimp,
     sys_platform,
     sys_platform,
     pypy_version,
     pypy_version,
+    with_environ,
 )
 )
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.mail import ErrorMail
 from celery.utils.mail import ErrorMail
@@ -55,42 +57,41 @@ def _get_test_config():
 test_config = _get_test_config()
 test_config = _get_test_config()
 
 
 
 
-class test_module(Case):
+class test_module(AppCase):
 
 
     def test_default_app(self):
     def test_default_app(self):
         self.assertEqual(_app.default_app, _state.default_app)
         self.assertEqual(_app.default_app, _state.default_app)
 
 
     def test_bugreport(self):
     def test_bugreport(self):
-        self.assertTrue(_app.bugreport())
+        self.assertTrue(_app.bugreport(app=self.app))
 
 
 
 
 class test_App(AppCase):
 class test_App(AppCase):
 
 
     def setup(self):
     def setup(self):
-        self.app.conf.update(test_config)
+        self.app.add_defaults(test_config)
 
 
     def test_task(self):
     def test_task(self):
-        app = Celery('foozibari', set_as_current=False)
+        with self.Celery('foozibari') as app:
 
 
-        def fun():
-            pass
+            def fun():
+                pass
 
 
-        fun.__module__ = '__main__'
-        task = app.task(fun)
-        self.assertEqual(task.name, app.main + '.fun')
+            fun.__module__ = '__main__'
+            task = app.task(fun)
+            self.assertEqual(task.name, app.main + '.fun')
 
 
     def test_with_config_source(self):
     def test_with_config_source(self):
-        app = Celery(set_as_current=False, config_source=ObjectConfig)
-        self.assertEqual(app.conf.FOO, 1)
-        self.assertEqual(app.conf.BAR, 2)
+        with self.Celery(config_source=ObjectConfig) as app:
+            self.assertEqual(app.conf.FOO, 1)
+            self.assertEqual(app.conf.BAR, 2)
 
 
+    @depends_on_current_app
     def test_task_windows_execv(self):
     def test_task_windows_execv(self):
-        app = Celery(set_as_current=False)
-
         prev, _appbase._EXECV = _appbase._EXECV, True
         prev, _appbase._EXECV = _appbase._EXECV, True
         try:
         try:
 
 
-            @app.task()
+            @self.app.task(shared=False)
             def foo():
             def foo():
                 pass
                 pass
 
 
@@ -101,41 +102,36 @@ class test_App(AppCase):
         assert not _appbase._EXECV
         assert not _appbase._EXECV
 
 
     def test_task_takes_no_args(self):
     def test_task_takes_no_args(self):
-        app = Celery(set_as_current=False)
-
         with self.assertRaises(TypeError):
         with self.assertRaises(TypeError):
-            @app.task(1)
+            @self.app.task(1)
             def foo():
             def foo():
                 pass
                 pass
 
 
     def test_add_defaults(self):
     def test_add_defaults(self):
-        app = Celery(set_as_current=False)
-
-        self.assertFalse(app.configured)
+        self.assertFalse(self.app.configured)
         _conf = {'FOO': 300}
         _conf = {'FOO': 300}
         conf = lambda: _conf
         conf = lambda: _conf
-        app.add_defaults(conf)
-        self.assertIn(conf, app._pending_defaults)
-        self.assertFalse(app.configured)
-        self.assertEqual(app.conf.FOO, 300)
-        self.assertTrue(app.configured)
-        self.assertFalse(app._pending_defaults)
+        self.app.add_defaults(conf)
+        self.assertIn(conf, self.app._pending_defaults)
+        self.assertFalse(self.app.configured)
+        self.assertEqual(self.app.conf.FOO, 300)
+        self.assertTrue(self.app.configured)
+        self.assertFalse(self.app._pending_defaults)
 
 
         # defaults not pickled
         # defaults not pickled
-        appr = loads(dumps(app))
+        appr = loads(dumps(self.app))
         with self.assertRaises(AttributeError):
         with self.assertRaises(AttributeError):
             appr.conf.FOO
             appr.conf.FOO
 
 
         # add more defaults after configured
         # add more defaults after configured
         conf2 = {'FOO': 'BAR'}
         conf2 = {'FOO': 'BAR'}
-        app.add_defaults(conf2)
-        self.assertEqual(app.conf.FOO, 'BAR')
+        self.app.add_defaults(conf2)
+        self.assertEqual(self.app.conf.FOO, 'BAR')
 
 
-        self.assertIn(_conf, app.conf.defaults)
-        self.assertIn(conf2, app.conf.defaults)
+        self.assertIn(_conf, self.app.conf.defaults)
+        self.assertIn(conf2, self.app.conf.defaults)
 
 
     def test_connection_or_acquire(self):
     def test_connection_or_acquire(self):
-
         with self.app.connection_or_acquire(block=True):
         with self.app.connection_or_acquire(block=True):
             self.assertTrue(self.app.pool._dirty)
             self.assertTrue(self.app.pool._dirty)
 
 
@@ -173,118 +169,90 @@ class test_App(AppCase):
             self.app.autodiscover_tasks(['proj.A', 'proj.B'])
             self.app.autodiscover_tasks(['proj.A', 'proj.B'])
             self.assertFalse(ep.called)
             self.assertFalse(ep.called)
 
 
+    @with_environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
     def test_with_broker(self):
-        prev = os.environ.get('CELERY_BROKER_URL')
-        os.environ.pop('CELERY_BROKER_URL', None)
-        try:
-            app = Celery(set_as_current=False, broker='foo://baribaz')
+        with self.Celery(broker='foo://baribaz') as app:
             self.assertEqual(app.conf.BROKER_HOST, 'foo://baribaz')
             self.assertEqual(app.conf.BROKER_HOST, 'foo://baribaz')
-        finally:
-            os.environ['CELERY_BROKER_URL'] = prev
 
 
     def test_repr(self):
     def test_repr(self):
         self.assertTrue(repr(self.app))
         self.assertTrue(repr(self.app))
 
 
     def test_custom_task_registry(self):
     def test_custom_task_registry(self):
-        app1 = Celery(set_as_current=False)
-        app2 = Celery(set_as_current=False, tasks=app1.tasks)
-        self.assertIs(app2.tasks, app1.tasks)
+        with self.Celery(tasks=self.app.tasks) as app2:
+            self.assertIs(app2.tasks, self.app.tasks)
 
 
     def test_include_argument(self):
     def test_include_argument(self):
-        app = Celery(set_as_current=False, include=('foo', 'bar.foo'))
-        self.assertEqual(app.conf.CELERY_IMPORTS, ('foo', 'bar.foo'))
+        with self.Celery(include=('foo', 'bar.foo')) as app:
+            self.assertEqual(app.conf.CELERY_IMPORTS, ('foo', 'bar.foo'))
 
 
     def test_set_as_current(self):
     def test_set_as_current(self):
         current = _state._tls.current_app
         current = _state._tls.current_app
         try:
         try:
-            app = Celery(set_as_current=True)
+            app = self.Celery(set_as_current=True)
             self.assertIs(_state._tls.current_app, app)
             self.assertIs(_state._tls.current_app, app)
         finally:
         finally:
             _state._tls.current_app = current
             _state._tls.current_app = current
 
 
     def test_current_task(self):
     def test_current_task(self):
-        app = Celery(set_as_current=False)
-
-        @app.task
-        def foo():
+        @self.app.task
+        def foo(shared=False):
             pass
             pass
 
 
         _state._task_stack.push(foo)
         _state._task_stack.push(foo)
         try:
         try:
-            self.assertEqual(app.current_task.name, foo.name)
+            self.assertEqual(self.app.current_task.name, foo.name)
         finally:
         finally:
             _state._task_stack.pop()
             _state._task_stack.pop()
 
 
     def test_task_not_shared(self):
     def test_task_not_shared(self):
         with patch('celery.app.base.shared_task') as sh:
         with patch('celery.app.base.shared_task') as sh:
-            app = Celery(set_as_current=False)
-
-            @app.task(shared=False)
+            @self.app.task(shared=False)
             def foo():
             def foo():
                 pass
                 pass
             self.assertFalse(sh.called)
             self.assertFalse(sh.called)
 
 
     def test_task_compat_with_filter(self):
     def test_task_compat_with_filter(self):
-        app = Celery(set_as_current=False, accept_magic_kwargs=True)
-        check = Mock()
+        with self.Celery(accept_magic_kwargs=True) as app:
+            check = Mock()
 
 
-        def filter(task):
-            check(task)
-            return task
+            def filter(task):
+                check(task)
+                return task
 
 
-        @app.task(filter=filter)
-        def foo():
-            pass
-        check.assert_called_with(foo)
+            @app.task(filter=filter, shared=False)
+            def foo():
+                pass
+            check.assert_called_with(foo)
 
 
     def test_task_with_filter(self):
     def test_task_with_filter(self):
-        app = Celery(set_as_current=False, accept_magic_kwargs=False)
-        check = Mock()
+        with self.Celery(accept_magic_kwargs=False) as app:
+            check = Mock()
 
 
-        def filter(task):
-            check(task)
-            return task
+            def filter(task):
+                check(task)
+                return task
 
 
-        assert not _appbase._EXECV
+            assert not _appbase._EXECV
 
 
-        @app.task(filter=filter)
-        def foo():
-            pass
-        check.assert_called_with(foo)
+            @app.task(filter=filter, shared=False)
+            def foo():
+                pass
+            check.assert_called_with(foo)
 
 
     def test_task_sets_main_name_MP_MAIN_FILE(self):
     def test_task_sets_main_name_MP_MAIN_FILE(self):
         from celery import utils as _utils
         from celery import utils as _utils
         _utils.MP_MAIN_FILE = __file__
         _utils.MP_MAIN_FILE = __file__
         try:
         try:
-            app = Celery('xuzzy', set_as_current=False)
+            with self.Celery('xuzzy') as app:
 
 
-            @app.task
-            def foo():
-                pass
+                @app.task
+                def foo():
+                    pass
 
 
-            self.assertEqual(foo.name, 'xuzzy.foo')
+                self.assertEqual(foo.name, 'xuzzy.foo')
         finally:
         finally:
             _utils.MP_MAIN_FILE = None
             _utils.MP_MAIN_FILE = None
 
 
-    def test_base_task_inherits_magic_kwargs_from_app(self):
-        from celery.task import Task as OldTask
-
-        class timkX(OldTask):
-            abstract = True
-
-        app = Celery(set_as_current=False, accept_magic_kwargs=True)
-        timkX.bind(app)
-        # see #918
-        self.assertFalse(timkX.accept_magic_kwargs)
-
-        from celery import Task as NewTask
-
-        class timkY(NewTask):
-            abstract = True
-
-        timkY.bind(app)
-        self.assertFalse(timkY.accept_magic_kwargs)
-
     def test_annotate_decorator(self):
     def test_annotate_decorator(self):
         from celery.app.task import Task
         from celery.app.task import Task
 
 
@@ -303,12 +271,11 @@ class test_App(AppCase):
                 return fun(*args, **kwargs)
                 return fun(*args, **kwargs)
             return _inner
             return _inner
 
 
-        app = Celery(set_as_current=False)
-        app.conf.CELERY_ANNOTATIONS = {
+        self.app.conf.CELERY_ANNOTATIONS = {
             adX.name: {'@__call__': deco}
             adX.name: {'@__call__': deco}
         }
         }
-        adX.bind(app)
-        self.assertIs(adX.app, app)
+        adX.bind(self.app)
+        self.assertIs(adX.app, self.app)
 
 
         i = adX()
         i = adX()
         i(2, 4, x=3)
         i(2, 4, x=3)
@@ -318,9 +285,7 @@ class test_App(AppCase):
         i.annotate()
         i.annotate()
 
 
     def test_apply_async_has__self__(self):
     def test_apply_async_has__self__(self):
-        app = Celery(set_as_current=False)
-
-        @app.task(__self__='hello')
+        @self.app.task(__self__='hello', shared=False)
         def aawsX():
         def aawsX():
             pass
             pass
 
 
@@ -330,25 +295,22 @@ class test_App(AppCase):
             self.assertEqual(args, ('hello', 4, 5))
             self.assertEqual(args, ('hello', 4, 5))
 
 
     def test_apply_async__connection_arg(self):
     def test_apply_async__connection_arg(self):
-        app = Celery(set_as_current=False)
-
-        @app.task()
+        @self.app.task(shared=False)
         def aacaX():
         def aacaX():
             pass
             pass
 
 
-        connection = app.connection('asd://')
+        connection = self.app.connection('asd://')
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             aacaX.apply_async(connection=connection)
             aacaX.apply_async(connection=connection)
 
 
     def test_apply_async_adds_children(self):
     def test_apply_async_adds_children(self):
         from celery._state import _task_stack
         from celery._state import _task_stack
-        app = Celery(set_as_current=False)
 
 
-        @app.task()
+        @self.app.task(shared=False)
         def a3cX1(self):
         def a3cX1(self):
             pass
             pass
 
 
-        @app.task()
+        @self.app.task(shared=False)
         def a3cX2(self):
         def a3cX2(self):
             pass
             pass
 
 
@@ -363,11 +325,6 @@ class test_App(AppCase):
         finally:
         finally:
             _task_stack.pop()
             _task_stack.pop()
 
 
-    def test_TaskSet(self):
-        ts = self.app.TaskSet()
-        self.assertListEqual(ts.tasks, [])
-        self.assertIs(ts.app, self.app)
-
     def test_pickle_app(self):
     def test_pickle_app(self):
         changes = dict(THE_FOO_BAR='bars',
         changes = dict(THE_FOO_BAR='bars',
                        THE_MII_MAR='jars')
                        THE_MII_MAR='jars')
@@ -461,19 +418,21 @@ class test_App(AppCase):
         x = self.app.Worker
         x = self.app.Worker
         self.assertIs(x.app, self.app)
         self.assertIs(x.app, self.app)
 
 
+    @depends_on_current_app
     def test_AsyncResult(self):
     def test_AsyncResult(self):
         x = self.app.AsyncResult('1')
         x = self.app.AsyncResult('1')
         self.assertIs(x.app, self.app)
         self.assertIs(x.app, self.app)
         r = loads(dumps(x))
         r = loads(dumps(x))
         # not set as current, so ends up as default app after reduce
         # not set as current, so ends up as default app after reduce
-        self.assertIs(r.app, _state.default_app)
+        self.assertIs(r.app, current_app._get_current_object())
 
 
     def test_get_active_apps(self):
     def test_get_active_apps(self):
         self.assertTrue(list(_state._get_active_apps()))
         self.assertTrue(list(_state._get_active_apps()))
 
 
-        app1 = Celery(set_as_current=False)
+        app1 = self.Celery()
         appid = id(app1)
         appid = id(app1)
         self.assertIn(app1, _state._get_active_apps())
         self.assertIn(app1, _state._get_active_apps())
+        app1.close()
         del(app1)
         del(app1)
 
 
         # weakref removed from list when app goes out of scope.
         # weakref removed from list when app goes out of scope.
@@ -607,7 +566,7 @@ class test_App(AppCase):
         self.assertFalse(task.app.mail_admins.called)
         self.assertFalse(task.app.mail_admins.called)
 
 
 
 
-class test_defaults(Case):
+class test_defaults(AppCase):
 
 
     def test_str_to_bool(self):
     def test_str_to_bool(self):
         for s in ('false', 'no', '0'):
         for s in ('false', 'no', '0'):
@@ -618,7 +577,7 @@ class test_defaults(Case):
             defaults.strtobool('unsure')
             defaults.strtobool('unsure')
 
 
 
 
-class test_debugging_utils(Case):
+class test_debugging_utils(AppCase):
 
 
     def test_enable_disable_trace(self):
     def test_enable_disable_trace(self):
         try:
         try:
@@ -630,7 +589,7 @@ class test_debugging_utils(Case):
             _app.disable_trace()
             _app.disable_trace()
 
 
 
 
-class test_pyimplementation(Case):
+class test_pyimplementation(AppCase):
 
 
     def test_platform_python_implementation(self):
     def test_platform_python_implementation(self):
         with platform_pyimp(lambda: 'Xython'):
         with platform_pyimp(lambda: 'Xython'):
@@ -656,36 +615,30 @@ class test_pyimplementation(Case):
                     self.assertEqual('CPython', pyimplementation())
                     self.assertEqual('CPython', pyimplementation())
 
 
 
 
-class test_shared_task(Case):
-
-    def setUp(self):
-        self._restore_app = current_app._get_current_object()
-
-    def tearDown(self):
-        self._restore_app.set_current()
+class test_shared_task(AppCase):
 
 
     def test_registers_to_all_apps(self):
     def test_registers_to_all_apps(self):
-        xproj = Celery('xproj')
-        xproj.finalize()
+        with self.Celery('xproj', set_as_current=True) as xproj:
+            xproj.finalize()
 
 
-        @shared_task
-        def foo():
-            return 42
+            @shared_task
+            def foo():
+                return 42
 
 
-        @shared_task()
-        def bar():
-            return 84
+            @shared_task()
+            def bar():
+                return 84
 
 
-        self.assertIs(foo.app, xproj)
-        self.assertIs(bar.app, xproj)
-        self.assertTrue(foo._get_current_object())
+            self.assertIs(foo.app, xproj)
+            self.assertIs(bar.app, xproj)
+            self.assertTrue(foo._get_current_object())
 
 
-        yproj = Celery('yproj')
-        self.assertIs(foo.app, yproj)
-        self.assertIs(bar.app, yproj)
+            with self.Celery('yproj', set_as_current=True) as yproj:
+                self.assertIs(foo.app, yproj)
+                self.assertIs(bar.app, yproj)
 
 
-        @shared_task()
-        def baz():
-            return 168
+                @shared_task()
+                def baz():
+                    return 168
 
 
-        self.assertIs(baz.app, yproj)
+                self.assertIs(baz.app, yproj)

+ 34 - 37
celery/tests/app/test_beat.py

@@ -8,12 +8,10 @@ from nose import SkipTest
 from pickle import dumps, loads
 from pickle import dumps, loads
 
 
 from celery import beat
 from celery import beat
-from celery import task
 from celery.five import keys, string_t
 from celery.five import keys, string_t
-from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.schedules import schedule
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.tests.case import AppCase, patch_settings
+from celery.tests.case import AppCase
 
 
 
 
 class Object(object):
 class Object(object):
@@ -49,10 +47,13 @@ class test_ScheduleEntry(AppCase):
     Entry = beat.ScheduleEntry
     Entry = beat.ScheduleEntry
 
 
     def create_entry(self, **kwargs):
     def create_entry(self, **kwargs):
-        entry = dict(name='celery.unittest.add',
-                     schedule=schedule(timedelta(seconds=10)),
-                     args=(2, 2),
-                     options={'routing_key': 'cpu'})
+        entry = dict(
+            name='celery.unittest.add',
+            schedule=timedelta(seconds=10),
+            args=(2, 2),
+            options={'routing_key': 'cpu'},
+            app=self.app,
+        )
         return self.Entry(**dict(entry, **kwargs))
         return self.Entry(**dict(entry, **kwargs))
 
 
     def test_next(self):
     def test_next(self):
@@ -68,6 +69,8 @@ class test_ScheduleEntry(AppCase):
 
 
     def test_is_due(self):
     def test_is_due(self):
         entry = self.create_entry(schedule=timedelta(seconds=10))
         entry = self.create_entry(schedule=timedelta(seconds=10))
+        self.assertIs(entry.app, self.app)
+        self.assertIs(entry.schedule.app, self.app)
         due1, next_time_to_run1 = entry.is_due()
         due1, next_time_to_run1 = entry.is_due()
         self.assertFalse(due1)
         self.assertFalse(due1)
         self.assertGreater(next_time_to_run1, 9)
         self.assertGreater(next_time_to_run1, 9)
@@ -111,7 +114,7 @@ class mScheduler(beat.Scheduler):
                           'args': args,
                           'args': args,
                           'kwargs': kwargs,
                           'kwargs': kwargs,
                           'options': options})
                           'options': options})
-        return AsyncResult(uuid(), app=self.app)
+        return self.app.AsyncResult(uuid())
 
 
 
 
 class mSchedulerSchedulingError(mScheduler):
 class mSchedulerSchedulingError(mScheduler):
@@ -151,19 +154,19 @@ class test_Scheduler(AppCase):
 
 
     def test_apply_async_uses_registered_task_instances(self):
     def test_apply_async_uses_registered_task_instances(self):
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def foo():
         def foo():
             pass
             pass
         foo.apply_async = Mock(name='foo.apply_async')
         foo.apply_async = Mock(name='foo.apply_async')
         assert foo.name in foo._get_app().tasks
         assert foo.name in foo._get_app().tasks
 
 
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
-        scheduler.apply_async(scheduler.Entry(task=foo.name))
+        scheduler.apply_async(scheduler.Entry(task=foo.name, app=self.app))
         self.assertTrue(foo.apply_async.called)
         self.assertTrue(foo.apply_async.called)
 
 
     def test_apply_async_should_not_sync(self):
     def test_apply_async_should_not_sync(self):
 
 
-        @task()
+        @self.app.task(shared=False)
         def not_sync():
         def not_sync():
             pass
             pass
         not_sync.apply_async = Mock()
         not_sync.apply_async = Mock()
@@ -172,12 +175,12 @@ class test_Scheduler(AppCase):
         s._do_sync = Mock()
         s._do_sync = Mock()
         s.should_sync = Mock()
         s.should_sync = Mock()
         s.should_sync.return_value = True
         s.should_sync.return_value = True
-        s.apply_async(s.Entry(task=not_sync.name))
+        s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         s._do_sync.assert_called_with()
         s._do_sync.assert_called_with()
 
 
         s._do_sync = Mock()
         s._do_sync = Mock()
         s.should_sync.return_value = False
         s.should_sync.return_value = False
-        s.apply_async(s.Entry(task=not_sync.name))
+        s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         self.assertFalse(s._do_sync.called)
         self.assertFalse(s._do_sync.called)
 
 
     @patch('celery.app.base.Celery.send_task')
     @patch('celery.app.base.Celery.send_task')
@@ -192,7 +195,7 @@ class test_Scheduler(AppCase):
 
 
     def test_maybe_entry(self):
     def test_maybe_entry(self):
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
-        entry = s.Entry(name='add every', task='tasks.add')
+        entry = s.Entry(name='add every', task='tasks.add', app=self.app)
         self.assertIs(s._maybe_entry(entry.name, entry), entry)
         self.assertIs(s._maybe_entry(entry.name, entry), entry)
         self.assertTrue(s._maybe_entry('add every', {
         self.assertTrue(s._maybe_entry('add every', {
             'task': 'tasks.add',
             'task': 'tasks.add',
@@ -213,29 +216,23 @@ class test_Scheduler(AppCase):
         callback(KeyError(), 5)
         callback(KeyError(), 5)
 
 
     def test_install_default_entries(self):
     def test_install_default_entries(self):
-        with patch_settings(self.app,
-                            CELERY_TASK_RESULT_EXPIRES=None,
-                            CELERYBEAT_SCHEDULE={}):
-            s = mScheduler(app=self.app)
-            s.install_default_entries({})
-            self.assertNotIn('celery.backend_cleanup', s.data)
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = None
+        self.app.conf.CELERYBEAT_SCHEDULE = {}
+        s = mScheduler(app=self.app)
+        s.install_default_entries({})
+        self.assertNotIn('celery.backend_cleanup', s.data)
         self.app.backend.supports_autoexpire = False
         self.app.backend.supports_autoexpire = False
-        with patch_settings(self.app,
-                            CELERY_TASK_RESULT_EXPIRES=30,
-                            CELERYBEAT_SCHEDULE={}):
-            s = mScheduler(app=self.app)
-            s.install_default_entries({})
-            self.assertIn('celery.backend_cleanup', s.data)
+
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = 30
+        s = mScheduler(app=self.app)
+        s.install_default_entries({})
+        self.assertIn('celery.backend_cleanup', s.data)
+
         self.app.backend.supports_autoexpire = True
         self.app.backend.supports_autoexpire = True
-        try:
-            with patch_settings(self.app,
-                                CELERY_TASK_RESULT_EXPIRES=31,
-                                CELERYBEAT_SCHEDULE={}):
-                s = mScheduler(app=self.app)
-                s.install_default_entries({})
-                self.assertNotIn('celery.backend_cleanup', s.data)
-        finally:
-            self.app.backend.supports_autoexpire = False
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = 31
+        s = mScheduler(app=self.app)
+        s.install_default_entries({})
+        self.assertNotIn('celery.backend_cleanup', s.data)
 
 
     def test_due_tick(self):
     def test_due_tick(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
@@ -490,7 +487,7 @@ class test_EmbeddedService(AppCase):
 class test_schedule(AppCase):
 class test_schedule(AppCase):
 
 
     def test_maybe_make_aware(self):
     def test_maybe_make_aware(self):
-        x = schedule(10)
+        x = schedule(10, app=self.app)
         x.utc_enabled = True
         x.utc_enabled = True
         d = x.maybe_make_aware(datetime.utcnow())
         d = x.maybe_make_aware(datetime.utcnow())
         self.assertTrue(d.tzinfo)
         self.assertTrue(d.tzinfo)
@@ -499,7 +496,7 @@ class test_schedule(AppCase):
         self.assertIsNone(d2.tzinfo)
         self.assertIsNone(d2.tzinfo)
 
 
     def test_to_local(self):
     def test_to_local(self):
-        x = schedule(10)
+        x = schedule(10, app=self.app)
         x.utc_enabled = True
         x.utc_enabled = True
         d = x.to_local(datetime.utcnow())
         d = x.to_local(datetime.utcnow())
         self.assertIsNone(d.tzinfo)
         self.assertIsNone(d.tzinfo)

+ 3 - 3
celery/tests/app/test_builtins.py

@@ -38,7 +38,7 @@ class test_map(BuiltinsCase):
 
 
     def test_run(self):
     def test_run(self):
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def map_mul(x):
         def map_mul(x):
             return x[0] * x[1]
             return x[0] * x[1]
 
 
@@ -52,7 +52,7 @@ class test_starmap(BuiltinsCase):
 
 
     def test_run(self):
     def test_run(self):
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def smap_mul(x, y):
         def smap_mul(x, y):
             return x * y
             return x * y
 
 
@@ -67,7 +67,7 @@ class test_chunks(BuiltinsCase):
     @patch('celery.canvas.chunks.apply_chunks')
     @patch('celery.canvas.chunks.apply_chunks')
     def test_run(self, apply_chunks):
     def test_run(self, apply_chunks):
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def chunks_mul(l):
         def chunks_mul(l):
             return l
             return l
 
 

+ 2 - 2
celery/tests/app/test_celery.py

@@ -1,10 +1,10 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 import celery
 import celery
 
 
 
 
-class test_celery_package(Case):
+class test_celery_package(AppCase):
 
 
     def test_version(self):
     def test_version(self):
         self.assertTrue(celery.VERSION)
         self.assertTrue(celery.VERSION)

+ 3 - 9
celery/tests/app/test_control.py

@@ -8,7 +8,7 @@ from kombu.pidbox import Mailbox
 
 
 from celery.app import control
 from celery.app import control
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.tests.case import AppCase, Case
+from celery.tests.case import AppCase
 
 
 
 
 class MockMailbox(Mailbox):
 class MockMailbox(Mailbox):
@@ -40,7 +40,7 @@ def with_mock_broadcast(fun):
     return _resets
     return _resets
 
 
 
 
-class test_flatten_reply(Case):
+class test_flatten_reply(AppCase):
 
 
     def test_flatten_reply(self):
     def test_flatten_reply(self):
         reply = [
         reply = [
@@ -65,9 +65,6 @@ class test_inspect(AppCase):
         self.prev, self.app.control = self.app.control, self.c
         self.prev, self.app.control = self.app.control, self.c
         self.i = self.c.inspect()
         self.i = self.c.inspect()
 
 
-    def tearDown(self):
-        self.app.control = self.prev
-
     def test_prepare_reply(self):
     def test_prepare_reply(self):
         self.assertDictEqual(self.i._prepare([{'w1': {'ok': 1}},
         self.assertDictEqual(self.i._prepare([{'w1': {'ok': 1}},
                                               {'w2': {'ok': 1}}]),
                                               {'w2': {'ok': 1}}]),
@@ -159,14 +156,11 @@ class test_Broadcast(AppCase):
         self.control = Control(app=self.app)
         self.control = Control(app=self.app)
         self.app.control = self.control
         self.app.control = self.control
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mytask():
         def mytask():
             pass
             pass
         self.mytask = mytask
         self.mytask = mytask
 
 
-    def tearDown(self):
-        del(self.app.control)
-
     def test_purge(self):
     def test_purge(self):
         self.control.purge()
         self.control.purge()
 
 

+ 4 - 4
celery/tests/app/test_defaults.py

@@ -7,15 +7,15 @@ from mock import Mock, patch
 
 
 from celery.app.defaults import NAMESPACES
 from celery.app.defaults import NAMESPACES
 
 
-from celery.tests.case import Case, pypy_version, sys_platform
+from celery.tests.case import AppCase, pypy_version, sys_platform
 
 
 
 
-class test_defaults(Case):
+class test_defaults(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         self._prev = sys.modules.pop('celery.app.defaults', None)
         self._prev = sys.modules.pop('celery.app.defaults', None)
 
 
-    def tearDown(self):
+    def teardown(self):
         if self._prev:
         if self._prev:
             sys.modules['celery.app.defaults'] = self._prev
             sys.modules['celery.app.defaults'] = self._prev
 
 

+ 2 - 2
celery/tests/app/test_exceptions.py

@@ -6,10 +6,10 @@ from datetime import datetime
 
 
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class test_RetryTaskError(Case):
+class test_RetryTaskError(AppCase):
 
 
     def test_when_datetime(self):
     def test_when_datetime(self):
         x = RetryTaskError('foo', KeyError(), when=datetime.utcnow())
         x = RetryTaskError('foo', KeyError(), when=datetime.utcnow())

+ 27 - 36
celery/tests/app/test_loaders.py

@@ -17,7 +17,9 @@ from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 from celery.utils.imports import NotAPackage
 from celery.utils.mail import SendmailWarning
 from celery.utils.mail import SendmailWarning
 
 
-from celery.tests.case import AppCase, Case
+from celery.tests.case import (
+    AppCase, Case, depends_on_current_app, with_environ,
+)
 
 
 
 
 class DummyLoader(base.BaseLoader):
 class DummyLoader(base.BaseLoader):
@@ -32,15 +34,15 @@ class test_loaders(AppCase):
         self.assertEqual(loaders.get_loader_cls('default'),
         self.assertEqual(loaders.get_loader_cls('default'),
                          default.Loader)
                          default.Loader)
 
 
+    @depends_on_current_app
     def test_current_loader(self):
     def test_current_loader(self):
-        self.app.set_current()  # XXX Compat test
         with self.assertWarnsRegex(
         with self.assertWarnsRegex(
                 CPendingDeprecationWarning,
                 CPendingDeprecationWarning,
                 r'deprecation'):
                 r'deprecation'):
             self.assertIs(loaders.current_loader(), self.app.loader)
             self.assertIs(loaders.current_loader(), self.app.loader)
 
 
+    @depends_on_current_app
     def test_load_settings(self):
     def test_load_settings(self):
-        self.app.set_current()  # XXX Compat test
         with self.assertWarnsRegex(
         with self.assertWarnsRegex(
                 CPendingDeprecationWarning,
                 CPendingDeprecationWarning,
                 r'deprecation'):
                 r'deprecation'):
@@ -105,15 +107,11 @@ class test_LoaderBase(AppCase):
 
 
     def test_import_default_modules(self):
     def test_import_default_modules(self):
         modnames = lambda l: [m.__name__ for m in l]
         modnames = lambda l: [m.__name__ for m in l]
-        prev, self.app.conf.CELERY_IMPORTS = (
-            self.app.conf.CELERY_IMPORTS, ('os', 'sys'))
-        try:
-            self.assertEqual(
-                sorted(modnames(self.loader.import_default_modules())),
-                sorted(modnames([os, sys])),
-            )
-        finally:
-            self.app.conf.CELERY_IMPORTS = prev
+        self.app.conf.CELERY_IMPORTS = ('os', 'sys')
+        self.assertEqual(
+            sorted(modnames(self.loader.import_default_modules())),
+            sorted(modnames([os, sys])),
+        )
 
 
     def test_import_from_cwd_custom_imp(self):
     def test_import_from_cwd_custom_imp(self):
 
 
@@ -163,19 +161,15 @@ class test_DefaultLoader(AppCase):
         find_module.side_effect = NotAPackage()
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)
         with self.assertRaises(NotAPackage):
         with self.assertRaises(NotAPackage):
-            l.read_configuration()
+            l.read_configuration(fail_silently=False)
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
+    @with_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
     def test_read_configuration_py_in_name(self, find_module):
     def test_read_configuration_py_in_name(self, find_module):
-        prev = os.environ['CELERY_CONFIG_MODULE']
-        os.environ['CELERY_CONFIG_MODULE'] = 'celeryconfig.py'
-        try:
-            find_module.side_effect = NotAPackage()
-            l = default.Loader(app=self.app)
-            with self.assertRaises(NotAPackage):
-                l.read_configuration()
-        finally:
-            os.environ['CELERY_CONFIG_MODULE'] = prev
+        find_module.side_effect = NotAPackage()
+        l = default.Loader(app=self.app)
+        with self.assertRaises(NotAPackage):
+            l.read_configuration(fail_silently=False)
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
     def test_read_configuration_importerror(self, find_module):
     def test_read_configuration_importerror(self, find_module):
@@ -183,9 +177,9 @@ class test_DefaultLoader(AppCase):
         find_module.side_effect = ImportError()
         find_module.side_effect = ImportError()
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)
         with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
         with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
-            l.read_configuration()
+            l.read_configuration(fail_silently=True)
         default.C_WNOCONF = False
         default.C_WNOCONF = False
-        l.read_configuration()
+        l.read_configuration(fail_silently=True)
 
 
     def test_read_configuration(self):
     def test_read_configuration(self):
         from types import ModuleType
         from types import ModuleType
@@ -193,17 +187,18 @@ class test_DefaultLoader(AppCase):
         class ConfigModule(ModuleType):
         class ConfigModule(ModuleType):
             pass
             pass
 
 
-        celeryconfig = ConfigModule('celeryconfig')
-        celeryconfig.CELERY_IMPORTS = ('os', 'sys')
         configname = os.environ.get('CELERY_CONFIG_MODULE') or 'celeryconfig'
         configname = os.environ.get('CELERY_CONFIG_MODULE') or 'celeryconfig'
+        celeryconfig = ConfigModule(configname)
+        celeryconfig.CELERY_IMPORTS = ('os', 'sys')
 
 
         prevconfig = sys.modules.get(configname)
         prevconfig = sys.modules.get(configname)
         sys.modules[configname] = celeryconfig
         sys.modules[configname] = celeryconfig
         try:
         try:
             l = default.Loader(app=self.app)
             l = default.Loader(app=self.app)
-            settings = l.read_configuration()
+            l.find_module = Mock(name='find_module')
+            settings = l.read_configuration(fail_silently=False)
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
-            settings = l.read_configuration()
+            settings = l.read_configuration(fail_silently=False)
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
             l.on_worker_init()
             l.on_worker_init()
         finally:
         finally:
@@ -248,14 +243,10 @@ class test_AppLoader(AppCase):
         self.loader = AppLoader(app=self.app)
         self.loader = AppLoader(app=self.app)
 
 
     def test_on_worker_init(self):
     def test_on_worker_init(self):
-        prev, self.app.conf.CELERY_IMPORTS = (
-            self.app.conf.CELERY_IMPORTS, ('subprocess', ))
-        try:
-            sys.modules.pop('subprocess', None)
-            self.loader.init_worker()
-            self.assertIn('subprocess', sys.modules)
-        finally:
-            self.app.conf.CELERY_IMPORTS = prev
+        self.app.conf.CELERY_IMPORTS = ('subprocess', )
+        sys.modules.pop('subprocess', None)
+        self.loader.init_worker()
+        self.assertIn('subprocess', sys.modules)
 
 
 
 
 class test_autodiscovery(Case):
 class test_autodiscovery(Case):

+ 8 - 11
celery/tests/app/test_log.py

@@ -8,7 +8,7 @@ from mock import patch, Mock
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery import signals
 from celery import signals
-from celery.app.log import Logging, TaskFormatter
+from celery.app.log import TaskFormatter
 from celery.utils.log import LoggingProxy
 from celery.utils.log import LoggingProxy
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.log import (
 from celery.utils.log import (
@@ -20,12 +20,12 @@ from celery.utils.log import (
     _patch_logger_class,
     _patch_logger_class,
 )
 )
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Case, override_stdouts, wrap_logger, get_handlers,
+    AppCase, override_stdouts, wrap_logger, get_handlers,
     restore_logging,
     restore_logging,
 )
 )
 
 
 
 
-class test_TaskFormatter(Case):
+class test_TaskFormatter(AppCase):
 
 
     def test_no_task(self):
     def test_no_task(self):
         class Record(object):
         class Record(object):
@@ -43,7 +43,7 @@ class test_TaskFormatter(Case):
         self.assertEqual(record.task_id, '???')
         self.assertEqual(record.task_id, '???')
 
 
 
 
-class test_ColorFormatter(Case):
+class test_ColorFormatter(AppCase):
 
 
     @patch('celery.utils.log.safe_str')
     @patch('celery.utils.log.safe_str')
     @patch('logging.Formatter.formatException')
     @patch('logging.Formatter.formatException')
@@ -139,10 +139,7 @@ class test_default_logger(AppCase):
     def test_setup_logging_subsystem_misc2(self):
     def test_setup_logging_subsystem_misc2(self):
         with restore_logging():
         with restore_logging():
             self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
             self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
-            try:
-                self.app.log.setup_logging_subsystem()
-            finally:
-                self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
+            self.app.log.setup_logging_subsystem()
 
 
     def test_get_default_logger(self):
     def test_get_default_logger(self):
         self.assertTrue(self.app.log.get_default_logger())
         self.assertTrue(self.app.log.get_default_logger())
@@ -277,7 +274,7 @@ class test_task_logger(test_default_logger):
         logging.root.manager.loggerDict.pop(logger.name, None)
         logging.root.manager.loggerDict.pop(logger.name, None)
         self.uid = uuid()
         self.uid = uuid()
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def test_task():
         def test_task():
             pass
             pass
         self.get_logger().handlers = []
         self.get_logger().handlers = []
@@ -285,7 +282,7 @@ class test_task_logger(test_default_logger):
         from celery._state import _task_stack
         from celery._state import _task_stack
         _task_stack.push(test_task)
         _task_stack.push(test_task)
 
 
-    def tearDown(self):
+    def teardown(self):
         from celery._state import _task_stack
         from celery._state import _task_stack
         _task_stack.pop()
         _task_stack.pop()
 
 
@@ -296,7 +293,7 @@ class test_task_logger(test_default_logger):
         return get_task_logger("test_task_logger")
         return get_task_logger("test_task_logger")
 
 
 
 
-class test_patch_logger_cls(Case):
+class test_patch_logger_cls(AppCase):
 
 
     def test_patches(self):
     def test_patches(self):
         _patch_logger_class()
         _patch_logger_class()

+ 34 - 48
celery/tests/app/test_registry.py

@@ -1,51 +1,40 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from celery import Celery
-from celery.app.registry import (
-    TaskRegistry,
-    _unpickle_task,
-    _unpickle_task_v2,
-)
-from celery.task import Task, PeriodicTask
-from celery.tests.case import AppCase, Case
+from celery.app.registry import _unpickle_task, _unpickle_task_v2
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
-class MockTask(Task):
-    name = 'celery.unittest.test_task'
-
-    def run(self, **kwargs):
-        return True
-
-
-class MockPeriodicTask(PeriodicTask):
-    name = 'celery.unittest.test_periodic_task'
-    run_every = 10
-
-    def run(self, **kwargs):
-        return True
+def returns():
+    return 1
 
 
 
 
 class test_unpickle_task(AppCase):
 class test_unpickle_task(AppCase):
 
 
-    def setup(self):
-        self.app = Celery(set_as_current=True)
-
+    @depends_on_current_app
     def test_unpickle_v1(self):
     def test_unpickle_v1(self):
         self.app.tasks['txfoo'] = 'bar'
         self.app.tasks['txfoo'] = 'bar'
         self.assertEqual(_unpickle_task('txfoo'), 'bar')
         self.assertEqual(_unpickle_task('txfoo'), 'bar')
 
 
+    @depends_on_current_app
     def test_unpickle_v2(self):
     def test_unpickle_v2(self):
         self.app.tasks['txfoo1'] = 'bar1'
         self.app.tasks['txfoo1'] = 'bar1'
         self.assertEqual(_unpickle_task_v2('txfoo1'), 'bar1')
         self.assertEqual(_unpickle_task_v2('txfoo1'), 'bar1')
         self.assertEqual(_unpickle_task_v2('txfoo1', module='celery'), 'bar1')
         self.assertEqual(_unpickle_task_v2('txfoo1', module='celery'), 'bar1')
 
 
 
 
-class test_TaskRegistry(Case):
+class test_TaskRegistry(AppCase):
+
+    def setup(self):
+        self.mytask = self.app.task(name='A', shared=False)(returns)
+        self.myperiodic = self.app.task(
+            name='B', shared=False, type='periodic',
+        )(returns)
 
 
     def test_NotRegistered_str(self):
     def test_NotRegistered_str(self):
-        self.assertTrue(repr(TaskRegistry.NotRegistered('tasks.add')))
+        self.assertTrue(repr(self.app.tasks.NotRegistered('tasks.add')))
 
 
     def assertRegisterUnregisterCls(self, r, task):
     def assertRegisterUnregisterCls(self, r, task):
+        r.unregister(task)
         with self.assertRaises(r.NotRegistered):
         with self.assertRaises(r.NotRegistered):
             r.unregister(task)
             r.unregister(task)
         r.register(task)
         r.register(task)
@@ -58,35 +47,32 @@ class test_TaskRegistry(Case):
         self.assertIn(task_name, r)
         self.assertIn(task_name, r)
 
 
     def test_task_registry(self):
     def test_task_registry(self):
-        r = TaskRegistry()
+        r = self.app._tasks
         self.assertIsInstance(r, dict, 'TaskRegistry is mapping')
         self.assertIsInstance(r, dict, 'TaskRegistry is mapping')
 
 
-        self.assertRegisterUnregisterCls(r, MockTask)
-        self.assertRegisterUnregisterCls(r, MockPeriodicTask)
+        self.assertRegisterUnregisterCls(r, self.mytask)
+        self.assertRegisterUnregisterCls(r, self.myperiodic)
 
 
-        r.register(MockPeriodicTask)
-        r.unregister(MockPeriodicTask.name)
-        self.assertNotIn(MockPeriodicTask, r)
-        r.register(MockPeriodicTask)
+        r.register(self.myperiodic)
+        r.unregister(self.myperiodic.name)
+        self.assertNotIn(self.myperiodic, r)
+        r.register(self.myperiodic)
 
 
         tasks = dict(r)
         tasks = dict(r)
-        self.assertIsInstance(tasks.get(MockTask.name), MockTask)
-        self.assertIsInstance(tasks.get(MockPeriodicTask.name),
-                              MockPeriodicTask)
+        self.assertIs(tasks.get(self.mytask.name), self.mytask)
+        self.assertIs(tasks.get(self.myperiodic.name), self.myperiodic)
 
 
-        self.assertIsInstance(r[MockTask.name], MockTask)
-        self.assertIsInstance(r[MockPeriodicTask.name],
-                              MockPeriodicTask)
+        self.assertIs(r[self.mytask.name], self.mytask)
+        self.assertIs(r[self.myperiodic.name], self.myperiodic)
 
 
-        r.unregister(MockTask)
-        self.assertNotIn(MockTask.name, r)
-        r.unregister(MockPeriodicTask)
-        self.assertNotIn(MockPeriodicTask.name, r)
+        r.unregister(self.mytask)
+        self.assertNotIn(self.mytask.name, r)
+        r.unregister(self.myperiodic)
+        self.assertNotIn(self.myperiodic.name, r)
 
 
-        self.assertTrue(MockTask().run())
-        self.assertTrue(MockPeriodicTask().run())
+        self.assertTrue(self.mytask.run())
+        self.assertTrue(self.myperiodic.run())
 
 
     def test_compat(self):
     def test_compat(self):
-        r = TaskRegistry()
-        r.regular()
-        r.periodic()
+        self.assertTrue(self.app.tasks.regular())
+        self.assertTrue(self.app.tasks.periodic())

+ 62 - 70
celery/tests/app/test_routes.py

@@ -1,7 +1,5 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from contextlib import contextmanager
-
 from kombu import Exchange
 from kombu import Exchange
 from kombu.utils.functional import maybe_evaluate
 from kombu.utils.functional import maybe_evaluate
 
 
@@ -22,17 +20,9 @@ def E(app, queues):
     return expand
     return expand
 
 
 
 
-@contextmanager
-def _queues(app, **queues):
-    prev_queues = app.conf.CELERY_QUEUES
-    prev_Queues = app.amqp.queues
+def set_queues(app, **queues):
     app.conf.CELERY_QUEUES = queues
     app.conf.CELERY_QUEUES = queues
     app.amqp.queues = app.amqp.Queues(queues)
     app.amqp.queues = app.amqp.Queues(queues)
-    try:
-        yield
-    finally:
-        app.conf.CELERY_QUEUES = prev_queues
-        app.amqp.queues = prev_Queues
 
 
 
 
 class RouteCase(AppCase):
 class RouteCase(AppCase):
@@ -54,7 +44,7 @@ class RouteCase(AppCase):
             'routing_key': self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
             'routing_key': self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
         }
         }
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mytask():
         def mytask():
             pass
             pass
         self.mytask = mytask
         self.mytask = mytask
@@ -63,24 +53,24 @@ class RouteCase(AppCase):
 class test_MapRoute(RouteCase):
 class test_MapRoute(RouteCase):
 
 
     def test_route_for_task_expanded_route(self):
     def test_route_for_task_expanded_route(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            expand = E(self.app, self.app.amqp.queues)
-            route = routes.MapRoute({self.mytask.name: {'queue': 'foo'}})
-            self.assertEqual(
-                expand(route.route_for_task(self.mytask.name))['queue'].name,
-                'foo',
-            )
-            self.assertIsNone(route.route_for_task('celery.awesome'))
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        expand = E(self.app, self.app.amqp.queues)
+        route = routes.MapRoute({self.mytask.name: {'queue': 'foo'}})
+        self.assertEqual(
+            expand(route.route_for_task(self.mytask.name))['queue'].name,
+            'foo',
+        )
+        self.assertIsNone(route.route_for_task('celery.awesome'))
 
 
     def test_route_for_task(self):
     def test_route_for_task(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            expand = E(self.app, self.app.amqp.queues)
-            route = routes.MapRoute({self.mytask.name: self.b_queue})
-            self.assertDictContainsSubset(
-                self.b_queue,
-                expand(route.route_for_task(self.mytask.name)),
-            )
-            self.assertIsNone(route.route_for_task('celery.awesome'))
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        expand = E(self.app, self.app.amqp.queues)
+        route = routes.MapRoute({self.mytask.name: self.b_queue})
+        self.assertDictContainsSubset(
+            self.b_queue,
+            expand(route.route_for_task(self.mytask.name)),
+        )
+        self.assertIsNone(route.route_for_task('celery.awesome'))
 
 
     def test_expand_route_not_found(self):
     def test_expand_route_not_found(self):
         expand = E(self.app, self.app.amqp.Queues(
         expand = E(self.app, self.app.amqp.Queues(
@@ -97,54 +87,56 @@ class test_lookup_route(RouteCase):
         self.assertDictEqual(router.queues, {})
         self.assertDictEqual(router.queues, {})
 
 
     def test_lookup_takes_first(self):
     def test_lookup_takes_first(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            R = routes.prepare(({self.mytask.name: {'queue': 'bar'}},
-                                {self.mytask.name: {'queue': 'foo'}}))
-            router = Router(self.app, R, self.app.amqp.queues)
-            self.assertEqual(router.route({}, self.mytask.name,
-                             args=[1, 2], kwargs={})['queue'].name, 'bar')
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        R = routes.prepare(({self.mytask.name: {'queue': 'bar'}},
+                            {self.mytask.name: {'queue': 'foo'}}))
+        router = Router(self.app, R, self.app.amqp.queues)
+        self.assertEqual(router.route({}, self.mytask.name,
+                         args=[1, 2], kwargs={})['queue'].name, 'bar')
 
 
     def test_expands_queue_in_options(self):
     def test_expands_queue_in_options(self):
-        with _queues(self.app):
-            R = routes.prepare(())
-            router = Router(
-                self.app, R, self.app.amqp.queues, create_missing=True,
-            )
-            # apply_async forwards all arguments, even exchange=None etc,
-            # so need to make sure it's merged correctly.
-            route = router.route(
-                {'queue': 'testq',
-                 'exchange': None,
-                 'routing_key': None,
-                 'immediate': False},
-                self.mytask.name,
-                args=[1, 2], kwargs={},
-            )
-            self.assertEqual(route['queue'].name, 'testq')
-            self.assertEqual(route['queue'].exchange, Exchange('testq'))
-            self.assertEqual(route['queue'].routing_key, 'testq')
-            self.assertEqual(route['immediate'], False)
+        set_queues(self.app)
+        R = routes.prepare(())
+        router = Router(
+            self.app, R, self.app.amqp.queues, create_missing=True,
+        )
+        # apply_async forwards all arguments, even exchange=None etc,
+        # so need to make sure it's merged correctly.
+        route = router.route(
+            {'queue': 'testq',
+             'exchange': None,
+             'routing_key': None,
+             'immediate': False},
+            self.mytask.name,
+            args=[1, 2], kwargs={},
+        )
+        self.assertEqual(route['queue'].name, 'testq')
+        self.assertEqual(route['queue'].exchange, Exchange('testq'))
+        self.assertEqual(route['queue'].routing_key, 'testq')
+        self.assertEqual(route['immediate'], False)
 
 
     def test_expand_destination_string(self):
     def test_expand_destination_string(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
-            x = Router(self.app, {}, self.app.amqp.queues)
-            dest = x.expand_destination('foo')
-            self.assertEqual(dest['queue'].name, 'foo')
+        set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
+        x = Router(self.app, {}, self.app.amqp.queues)
+        dest = x.expand_destination('foo')
+        self.assertEqual(dest['queue'].name, 'foo')
 
 
     def test_lookup_paths_traversed(self):
     def test_lookup_paths_traversed(self):
-        with _queues(self.app, foo=self.a_queue, bar=self.b_queue, **{
-                self.app.conf.CELERY_DEFAULT_QUEUE: self.d_queue}):
-            R = routes.prepare((
-                {'celery.xaza': {'queue': 'bar'}},
-                {self.mytask.name: {'queue': 'foo'}}
-            ))
-            router = Router(self.app, R, self.app.amqp.queues)
-            self.assertEqual(router.route({}, self.mytask.name,
-                             args=[1, 2], kwargs={})['queue'].name, 'foo')
-            self.assertEqual(
-                router.route({}, 'celery.poza')['queue'].name,
-                self.app.conf.CELERY_DEFAULT_QUEUE,
-            )
+        set_queues(
+            self.app, foo=self.a_queue, bar=self.b_queue,
+            **{self.app.conf.CELERY_DEFAULT_QUEUE: self.d_queue}
+        )
+        R = routes.prepare((
+            {'celery.xaza': {'queue': 'bar'}},
+            {self.mytask.name: {'queue': 'foo'}}
+        ))
+        router = Router(self.app, R, self.app.amqp.queues)
+        self.assertEqual(router.route({}, self.mytask.name,
+                         args=[1, 2], kwargs={})['queue'].name, 'foo')
+        self.assertEqual(
+            router.route({}, 'celery.poza')['queue'].name,
+            self.app.conf.CELERY_DEFAULT_QUEUE,
+        )
 
 
 
 
 class test_prepare(AppCase):
 class test_prepare(AppCase):

+ 2 - 2
celery/tests/app/test_schedules.py

@@ -365,8 +365,8 @@ class test_crontab_remaining_estimate(AppCase):
 
 
     def test_not_weekmonthdayyear(self):
     def test_not_weekmonthdayyear(self):
         next = self.next_ocurrance(
         next = self.next_ocurrance(
-            crontab(minute=[5, 42], day_of_week='fri,sat',
-                    day_of_month=29, month_of_year='2-10'),
+            self.crontab(minute=[5, 42], day_of_week='fri,sat',
+                         day_of_month=29, month_of_year='2-10'),
             datetime(2010, 1, 28, 14, 30, 15),
             datetime(2010, 1, 28, 14, 30, 15),
         )
         )
         self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
         self.assertEqual(next, datetime(2010, 5, 29, 0, 5))

+ 2 - 2
celery/tests/app/test_utils.py

@@ -4,10 +4,10 @@ from collections import Mapping, MutableMapping
 
 
 from celery.app.utils import Settings, bugreport
 from celery.app.utils import Settings, bugreport
 
 
-from celery.tests.case import AppCase, Case, Mock
+from celery.tests.case import AppCase, Mock
 
 
 
 
-class TestSettings(Case):
+class TestSettings(AppCase):
     """
     """
     Tests of celery.app.utils.Settings
     Tests of celery.app.utils.Settings
     """
     """

+ 7 - 8
celery/tests/backends/test_amqp.py

@@ -16,7 +16,9 @@ from celery.exceptions import TimeoutError
 from celery.five import Empty, Queue, range
 from celery.five import Empty, Queue, range
 from celery.utils import uuid
 from celery.utils import uuid
 
 
-from celery.tests.case import AppCase, sleepdeprived, Mock
+from celery.tests.case import (
+    AppCase, Mock, depends_on_current_app, sleepdeprived,
+)
 
 
 
 
 class SomeClass(object):
 class SomeClass(object):
@@ -43,6 +45,7 @@ class test_AMQPBackend(AppCase):
         self.assertTrue(tb2._cache.get(tid))
         self.assertTrue(tb2._cache.get(tid))
         self.assertTrue(tb2.get_result(tid), 42)
         self.assertTrue(tb2.get_result(tid), 42)
 
 
+    @depends_on_current_app
     def test_pickleable(self):
     def test_pickleable(self):
         self.assertTrue(loads(dumps(self.create_backend())))
         self.assertTrue(loads(dumps(self.create_backend())))
 
 
@@ -322,14 +325,10 @@ class test_AMQPBackend(AppCase):
     def test_no_expires(self):
     def test_no_expires(self):
         b = self.create_backend(expires=None)
         b = self.create_backend(expires=None)
         app = self.app
         app = self.app
-        prev = app.conf.CELERY_TASK_RESULT_EXPIRES
         app.conf.CELERY_TASK_RESULT_EXPIRES = None
         app.conf.CELERY_TASK_RESULT_EXPIRES = None
-        try:
-            b = self.create_backend(expires=None)
-            with self.assertRaises(KeyError):
-                b.queue_arguments['x-expires']
-        finally:
-            app.conf.CELERY_TASK_RESULT_EXPIRES = prev
+        b = self.create_backend(expires=None)
+        with self.assertRaises(KeyError):
+            b.queue_arguments['x-expires']
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         self.create_backend().process_cleanup()
         self.create_backend().process_cleanup()

+ 13 - 11
celery/tests/backends/test_backends.py

@@ -5,17 +5,19 @@ from mock import patch
 from celery import backends
 from celery import backends
 from celery.backends.amqp import AMQPBackend
 from celery.backends.amqp import AMQPBackend
 from celery.backends.cache import CacheBackend
 from celery.backends.cache import CacheBackend
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
 class test_backends(AppCase):
 class test_backends(AppCase):
 
 
     def test_get_backend_aliases(self):
     def test_get_backend_aliases(self):
-        expects = [('amqp', AMQPBackend),
-                   ('cache', CacheBackend)]
-        for expect_name, expect_cls in expects:
+        expects = [('amqp://', AMQPBackend),
+                   ('cache+memory://', CacheBackend)]
+
+        for url, expect_cls in expects:
+            backend, url = backends.get_backend_by_url(url, self.app.loader)
             self.assertIsInstance(
             self.assertIsInstance(
-                backends.get_backend_cls(expect_name)(app=self.app),
+                backend(app=self.app, url=url),
                 expect_cls,
                 expect_cls,
             )
             )
 
 
@@ -23,22 +25,22 @@ class test_backends(AppCase):
         backends.get_backend_cls.clear()
         backends.get_backend_cls.clear()
         hits = backends.get_backend_cls.hits
         hits = backends.get_backend_cls.hits
         misses = backends.get_backend_cls.misses
         misses = backends.get_backend_cls.misses
-        self.assertTrue(backends.get_backend_cls('amqp'))
+        self.assertTrue(backends.get_backend_cls('amqp', self.app.loader))
         self.assertEqual(backends.get_backend_cls.misses, misses + 1)
         self.assertEqual(backends.get_backend_cls.misses, misses + 1)
-        self.assertTrue(backends.get_backend_cls('amqp'))
+        self.assertTrue(backends.get_backend_cls('amqp', self.app.loader))
         self.assertEqual(backends.get_backend_cls.hits, hits + 1)
         self.assertEqual(backends.get_backend_cls.hits, hits + 1)
 
 
     def test_unknown_backend(self):
     def test_unknown_backend(self):
         with self.assertRaises(ImportError):
         with self.assertRaises(ImportError):
-            backends.get_backend_cls('fasodaopjeqijwqe')
+            backends.get_backend_cls('fasodaopjeqijwqe', self.app.loader)
 
 
+    @depends_on_current_app
     def test_default_backend(self):
     def test_default_backend(self):
-        self.app.set_current()  # XXX compat test
         self.assertEqual(backends.default_backend, self.app.backend)
         self.assertEqual(backends.default_backend, self.app.backend)
 
 
     def test_backend_by_url(self, url='redis://localhost/1'):
     def test_backend_by_url(self, url='redis://localhost/1'):
         from celery.backends.redis import RedisBackend
         from celery.backends.redis import RedisBackend
-        backend, url_ = backends.get_backend_by_url(url)
+        backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         self.assertIs(backend, RedisBackend)
         self.assertIs(backend, RedisBackend)
         self.assertEqual(url_, url)
         self.assertEqual(url_, url)
 
 
@@ -46,4 +48,4 @@ class test_backends(AppCase):
         with patch('celery.backends.symbol_by_name') as sbn:
         with patch('celery.backends.symbol_by_name') as sbn:
             sbn.side_effect = ValueError()
             sbn.side_effect = ValueError()
             with self.assertRaises(ValueError):
             with self.assertRaises(ValueError):
-                backends.get_backend_cls('xxx.xxx:foo')
+                backends.get_backend_cls('xxx.xxx:foo', self.app.loader)

+ 10 - 14
celery/tests/backends/test_base.py

@@ -9,7 +9,6 @@ from nose import SkipTest
 
 
 from celery.exceptions import ChordError
 from celery.exceptions import ChordError
 from celery.five import items, range
 from celery.five import items, range
-from celery.result import AsyncResult, GroupResult
 from celery.utils import serialization
 from celery.utils import serialization
 from celery.utils.serialization import subclass_exception
 from celery.utils.serialization import subclass_exception
 from celery.utils.serialization import find_pickleable_exception as fnpe
 from celery.utils.serialization import find_pickleable_exception as fnpe
@@ -24,7 +23,7 @@ from celery.backends.base import (
 )
 )
 from celery.utils import uuid
 from celery.utils import uuid
 
 
-from celery.tests.case import AppCase, Case
+from celery.tests.case import AppCase
 
 
 
 
 class wrapobject(object):
 class wrapobject(object):
@@ -66,18 +65,15 @@ class test_BaseBackend_interface(AppCase):
         self.b.on_chord_part_return(None)
         self.b.on_chord_part_return(None)
 
 
     def test_on_chord_apply(self, unlock='celery.chord_unlock'):
     def test_on_chord_apply(self, unlock='celery.chord_unlock'):
-        p, self.app.tasks[unlock] = self.app.tasks.get(unlock), Mock()
-        try:
-            self.b.on_chord_apply(
-                'dakj221', 'sdokqweok',
-                result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
-            )
-            self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
-        finally:
-            self.app.tasks[unlock] = p
+        self.app.tasks[unlock] = Mock()
+        self.b.on_chord_apply(
+            'dakj221', 'sdokqweok',
+            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
+        )
+        self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
 
 
 
 
-class test_exception_pickle(Case):
+class test_exception_pickle(AppCase):
 
 
     def test_oldstyle(self):
     def test_oldstyle(self):
         if Oldstyle is None:
         if Oldstyle is None:
@@ -224,7 +220,7 @@ class test_BaseBackend_dict(AppCase):
             self.assertTrue(args[2])
             self.assertTrue(args[2])
 
 
     def test_prepare_value_serializes_group_result(self):
     def test_prepare_value_serializes_group_result(self):
-        g = GroupResult('group_id', [AsyncResult('foo')])
+        g = self.app.GroupResult('group_id', [self.app.AsyncResult('foo')])
         self.assertIsInstance(self.b.prepare_value(g), (list, tuple))
         self.assertIsInstance(self.b.prepare_value(g), (list, tuple))
 
 
     def test_is_cached(self):
     def test_is_cached(self):
@@ -286,7 +282,7 @@ class test_KeyValueStoreBackend(AppCase):
     @contextmanager
     @contextmanager
     def _chord_part_context(self, b):
     def _chord_part_context(self, b):
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def callback(result):
         def callback(result):
             pass
             pass
 
 

+ 7 - 13
celery/tests/backends/test_cache.py

@@ -8,12 +8,11 @@ from contextlib import contextmanager
 from kombu.utils.encoding import str_to_bytes
 from kombu.utils.encoding import str_to_bytes
 from mock import Mock, patch
 from mock import Mock, patch
 
 
+from celery import subtask
 from celery import states
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.five import items, string, text_t
 from celery.five import items, string, text_t
-from celery.result import AsyncResult
-from celery.task import subtask
 from celery.utils import uuid
 from celery.utils import uuid
 
 
 from celery.tests.case import AppCase, mask_modules, reset_modules
 from celery.tests.case import AppCase, mask_modules, reset_modules
@@ -32,14 +31,9 @@ class test_CacheBackend(AppCase):
         self.tid = uuid()
         self.tid = uuid()
 
 
     def test_no_backend(self):
     def test_no_backend(self):
-        prev, self.app.conf.CELERY_CACHE_BACKEND = (
-            self.app.conf.CELERY_CACHE_BACKEND, None,
-        )
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                CacheBackend(backend=None, app=self.app)
-        finally:
-            self.app.conf.CELERY_CACHE_BACKEND = prev
+        self.app.conf.CELERY_CACHE_BACKEND = None
+        with self.assertRaises(ImproperlyConfigured):
+            CacheBackend(backend=None, app=self.app)
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
         self.assertEqual(self.tb.get_status(self.tid), states.PENDING)
         self.assertEqual(self.tb.get_status(self.tid), states.PENDING)
@@ -67,7 +61,7 @@ class test_CacheBackend(AppCase):
 
 
     def test_on_chord_apply(self):
     def test_on_chord_apply(self):
         tb = CacheBackend(backend='memory://', app=self.app)
         tb = CacheBackend(backend='memory://', app=self.app)
-        gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
+        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
         tb.on_chord_apply(gid, {}, result=res)
         tb.on_chord_apply(gid, {}, result=res)
 
 
     @patch('celery.result.GroupResult.restore')
     @patch('celery.result.GroupResult.restore')
@@ -83,7 +77,7 @@ class test_CacheBackend(AppCase):
         self.app.tasks['foobarbaz'] = task
         self.app.tasks['foobarbaz'] = task
         task.request.chord = subtask(task)
         task.request.chord = subtask(task)
 
 
-        gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
+        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
         task.request.group = gid
         task.request.group = gid
         tb.on_chord_apply(gid, {}, result=res)
         tb.on_chord_apply(gid, {}, result=res)
 
 
@@ -104,7 +98,7 @@ class test_CacheBackend(AppCase):
 
 
     def test_forget(self):
     def test_forget(self):
         self.tb.mark_as_done(self.tid, {'foo': 'bar'})
         self.tb.mark_as_done(self.tid, {'foo': 'bar'})
-        x = AsyncResult(self.tid, backend=self.tb)
+        x = self.app.AsyncResult(self.tid, backend=self.tb)
         x.forget()
         x.forget()
         self.assertIsNone(x.result)
         self.assertIsNone(x.result)
 
 

+ 22 - 25
celery/tests/backends/test_cassandra.py

@@ -5,10 +5,9 @@ import socket
 from mock import Mock
 from mock import Mock
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
-from celery import Celery
 from celery import states
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase, mock_module
+from celery.tests.case import AppCase, mock_module, depends_on_current_app
 
 
 
 
 class Object(object):
 class Object(object):
@@ -46,6 +45,13 @@ def install_exceptions(mod):
 
 
 class test_CassandraBackend(AppCase):
 class test_CassandraBackend(AppCase):
 
 
+    def setup(self):
+        self.app.conf.update(
+            CASSANDRA_SERVERS=['example.com'],
+            CASSANDRA_KEYSPACE='keyspace',
+            CASSANDRA_COLUMN_FAMILY='columns',
+        )
+
     def test_init_no_pycassa(self):
     def test_init_no_pycassa(self):
         with mock_module('pycassa'):
         with mock_module('pycassa'):
             from celery.backends import cassandra as mod
             from celery.backends import cassandra as mod
@@ -56,13 +62,6 @@ class test_CassandraBackend(AppCase):
             finally:
             finally:
                 mod.pycassa = prev
                 mod.pycassa = prev
 
 
-    def get_app(self):
-        celery = Celery(set_as_current=False)
-        celery.conf.CASSANDRA_SERVERS = ['example.com']
-        celery.conf.CASSANDRA_KEYSPACE = 'keyspace'
-        celery.conf.CASSANDRA_COLUMN_FAMILY = 'columns'
-        return celery
-
     def test_init_with_and_without_LOCAL_QUROM(self):
     def test_init_with_and_without_LOCAL_QUROM(self):
         with mock_module('pycassa'):
         with mock_module('pycassa'):
             from celery.backends import cassandra as mod
             from celery.backends import cassandra as mod
@@ -71,23 +70,25 @@ class test_CassandraBackend(AppCase):
             cons = mod.pycassa.ConsistencyLevel = Object()
             cons = mod.pycassa.ConsistencyLevel = Object()
             cons.LOCAL_QUORUM = 'foo'
             cons.LOCAL_QUORUM = 'foo'
 
 
-            app = self.get_app()
-            app.conf.CASSANDRA_READ_CONSISTENCY = 'LOCAL_FOO'
-            app.conf.CASSANDRA_WRITE_CONSISTENCY = 'LOCAL_FOO'
+            self.app.conf.CASSANDRA_READ_CONSISTENCY = 'LOCAL_FOO'
+            self.app.conf.CASSANDRA_WRITE_CONSISTENCY = 'LOCAL_FOO'
 
 
-            mod.CassandraBackend(app=app)
+            mod.CassandraBackend(app=self.app)
             cons.LOCAL_FOO = 'bar'
             cons.LOCAL_FOO = 'bar'
-            mod.CassandraBackend(app=app)
+            mod.CassandraBackend(app=self.app)
 
 
             # no servers raises ImproperlyConfigured
             # no servers raises ImproperlyConfigured
             with self.assertRaises(ImproperlyConfigured):
             with self.assertRaises(ImproperlyConfigured):
-                app.conf.CASSANDRA_SERVERS = None
-                mod.CassandraBackend(app=app, keyspace='b', column_family='c')
+                self.app.conf.CASSANDRA_SERVERS = None
+                mod.CassandraBackend(
+                    app=self.app, keyspace='b', column_family='c',
+                )
 
 
+    @depends_on_current_app
     def test_reduce(self):
     def test_reduce(self):
         with mock_module('pycassa'):
         with mock_module('pycassa'):
             from celery.backends.cassandra import CassandraBackend
             from celery.backends.cassandra import CassandraBackend
-            self.assertTrue(loads(dumps(CassandraBackend(app=self.get_app()))))
+            self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
 
 
     def test_get_task_meta_for(self):
     def test_get_task_meta_for(self):
         with mock_module('pycassa'):
         with mock_module('pycassa'):
@@ -96,8 +97,7 @@ class test_CassandraBackend(AppCase):
             install_exceptions(mod.pycassa)
             install_exceptions(mod.pycassa)
             mod.Thrift = Mock()
             mod.Thrift = Mock()
             install_exceptions(mod.Thrift)
             install_exceptions(mod.Thrift)
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             Get_Column = x._get_column_family = Mock()
             Get_Column = x._get_column_family = Mock()
             get_column = Get_Column.return_value = Mock()
             get_column = Get_Column.return_value = Mock()
             get = get_column.get
             get = get_column.get
@@ -155,8 +155,7 @@ class test_CassandraBackend(AppCase):
             install_exceptions(mod.pycassa)
             install_exceptions(mod.pycassa)
             mod.Thrift = Mock()
             mod.Thrift = Mock()
             install_exceptions(mod.Thrift)
             install_exceptions(mod.Thrift)
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             Get_Column = x._get_column_family = Mock()
             Get_Column = x._get_column_family = Mock()
             cf = Get_Column.return_value = Mock()
             cf = Get_Column.return_value = Mock()
             x.detailed_mode = False
             x.detailed_mode = False
@@ -171,8 +170,7 @@ class test_CassandraBackend(AppCase):
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         with mock_module('pycassa'):
         with mock_module('pycassa'):
             from celery.backends import cassandra as mod
             from celery.backends import cassandra as mod
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             x._column_family = None
             x._column_family = None
             x.process_cleanup()
             x.process_cleanup()
 
 
@@ -185,8 +183,7 @@ class test_CassandraBackend(AppCase):
             from celery.backends import cassandra as mod
             from celery.backends import cassandra as mod
             mod.pycassa = Mock()
             mod.pycassa = Mock()
             install_exceptions(mod.pycassa)
             install_exceptions(mod.pycassa)
-            app = self.get_app()
-            x = mod.CassandraBackend(app=app)
+            x = mod.CassandraBackend(app=self.app)
             self.assertTrue(x._get_column_family())
             self.assertTrue(x._get_column_family())
             self.assertIsNotNone(x._column_family)
             self.assertIsNotNone(x._column_family)
             self.assertIs(x._get_column_family(), x._column_family)
             self.assertIs(x._get_column_family(), x._column_family)

+ 46 - 58
celery/tests/backends/test_couchbase.py

@@ -3,7 +3,6 @@ from __future__ import absolute_import
 from mock import MagicMock, Mock, patch, sentinel
 from mock import MagicMock, Mock, patch, sentinel
 from nose import SkipTest
 from nose import SkipTest
 
 
-from celery import Celery
 from celery.backends import couchbase as module
 from celery.backends import couchbase as module
 from celery.backends.couchbase import CouchBaseBackend
 from celery.backends.couchbase import CouchBaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
@@ -36,19 +35,16 @@ class test_CouchBaseBackend(AppCase):
 
 
     def test_init_no_settings(self):
     def test_init_no_settings(self):
         """test init no settings"""
         """test init no settings"""
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = []
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = []
         with self.assertRaises(ImproperlyConfigured):
         with self.assertRaises(ImproperlyConfigured):
-            CouchBaseBackend(app=celery)
+            CouchBaseBackend(app=self.app)
 
 
     def test_init_settings_is_None(self):
     def test_init_settings_is_None(self):
         """Test init settings is None"""
         """Test init settings is None"""
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
-        CouchBaseBackend(app=celery)
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
+        CouchBaseBackend(app=self.app)
 
 
     def test_get_connection_connection_exists(self):
     def test_get_connection_connection_exists(self):
-        """Test get existing connection"""
         with patch('couchbase.connection.Connection') as mock_Connection:
         with patch('couchbase.connection.Connection') as mock_Connection:
             self.backend._connection = sentinel._connection
             self.backend._connection = sentinel._connection
 
 
@@ -58,89 +54,81 @@ class test_CouchBaseBackend(AppCase):
             self.assertFalse(mock_Connection.called)
             self.assertFalse(mock_Connection.called)
 
 
     def test_get(self):
     def test_get(self):
-        """Test get
+        """test_get
 
 
         CouchBaseBackend.get should return  and take two params
         CouchBaseBackend.get should return  and take two params
         db conn to couchbase is mocked.
         db conn to couchbase is mocked.
         TODO Should test on key not exists
         TODO Should test on key not exists
 
 
         """
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
-
-            x = CouchBaseBackend(app=app)
-            x._connection = Mock()
-            mocked_get = x._connection.get = Mock()
-            mocked_get.return_value.value = sentinel.retval
-            # should return None
-            self.assertEqual(x.get('1f3fab'), sentinel.retval)
-            x._connection.get.assert_called_once_with('1f3fab')
-
-    # betta
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
+        x = CouchBaseBackend(app=self.app)
+        x._connection = Mock()
+        mocked_get = x._connection.get = Mock()
+        mocked_get.return_value.value = sentinel.retval
+        # should return None
+        self.assertEqual(x.get('1f3fab'), sentinel.retval)
+        x._connection.get.assert_called_once_with('1f3fab')
+
     def test_set(self):
     def test_set(self):
-        """Test set
+        """test_set
 
 
         CouchBaseBackend.set should return None and take two params
         CouchBaseBackend.set should return None and take two params
         db conn to couchbase is mocked.
         db conn to couchbase is mocked.
 
 
         """
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
-            x = CouchBaseBackend(app=app)
-            x._connection = MagicMock()
-            x._connection.set = MagicMock()
-            # should return None
-            self.assertIsNone(x.set(sentinel.key, sentinel.value))
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = None
+        x = CouchBaseBackend(app=self.app)
+        x._connection = MagicMock()
+        x._connection.set = MagicMock()
+        # should return None
+        self.assertIsNone(x.set(sentinel.key, sentinel.value))
 
 
     def test_delete(self):
     def test_delete(self):
-        """Test delete
+        """test_delete
 
 
         CouchBaseBackend.delete should return and take two params
         CouchBaseBackend.delete should return and take two params
         db conn to couchbase is mocked.
         db conn to couchbase is mocked.
         TODO Should test on key not exists
         TODO Should test on key not exists
 
 
         """
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
-            x = CouchBaseBackend(app=app)
-            x._connection = Mock()
-            mocked_delete = x._connection.delete = Mock()
-            mocked_delete.return_value = None
-            # should return None
-            self.assertIsNone(x.delete('1f3fab'))
-            x._connection.delete.assert_called_once_with('1f3fab')
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {}
+        x = CouchBaseBackend(app=self.app)
+        x._connection = Mock()
+        mocked_delete = x._connection.delete = Mock()
+        mocked_delete.return_value = None
+        # should return None
+        self.assertIsNone(x.delete('1f3fab'))
+        x._connection.delete.assert_called_once_with('1f3fab')
 
 
     def test_config_params(self):
     def test_config_params(self):
-        """test celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS
+        """test_config_params
 
 
         celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS is properly set
         celery.conf.CELERY_COUCHBASE_BACKEND_SETTINGS is properly set
         """
         """
-        with Celery(set_as_current=False) as app:
-            app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {
-                'bucket': 'mycoolbucket',
-                'host': ['here.host.com', 'there.host.com'],
-                'username': 'johndoe',
-                'password': 'mysecret',
-                'port': '1234',
-            }
-            x = CouchBaseBackend(app=app)
-            self.assertEqual(x.bucket, 'mycoolbucket')
-            self.assertEqual(x.host, ['here.host.com', 'there.host.com'],)
-            self.assertEqual(x.username, 'johndoe',)
-            self.assertEqual(x.password, 'mysecret')
-            self.assertEqual(x.port, 1234)
+        self.app.conf.CELERY_COUCHBASE_BACKEND_SETTINGS = {
+            'bucket': 'mycoolbucket',
+            'host': ['here.host.com', 'there.host.com'],
+            'username': 'johndoe',
+            'password': 'mysecret',
+            'port': '1234',
+        }
+        x = CouchBaseBackend(app=self.app)
+        self.assertEqual(x.bucket, 'mycoolbucket')
+        self.assertEqual(x.host, ['here.host.com', 'there.host.com'],)
+        self.assertEqual(x.username, 'johndoe',)
+        self.assertEqual(x.password, 'mysecret')
+        self.assertEqual(x.port, 1234)
 
 
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
-        """test get backend by url"""
         from celery.backends.couchbase import CouchBaseBackend
         from celery.backends.couchbase import CouchBaseBackend
-        backend, url_ = backends.get_backend_by_url(url)
+        backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         self.assertIs(backend, CouchBaseBackend)
         self.assertIs(backend, CouchBaseBackend)
         self.assertEqual(url_, url)
         self.assertEqual(url_, url)
 
 
     def test_backend_params_by_url(self):
     def test_backend_params_by_url(self):
-        """test get backend params by url"""
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
-        with Celery(set_as_current=False, backend=url) as app:
+        with self.Celery(backend=url) as app:
             x = app.backend
             x = app.backend
             self.assertEqual(x.bucket, "mycoolbucket")
             self.assertEqual(x.bucket, "mycoolbucket")
             self.assertEqual(x.host, "myhost")
             self.assertEqual(x.host, "myhost")

+ 20 - 22
celery/tests/backends/test_database.py

@@ -7,11 +7,11 @@ from pickle import loads, dumps
 
 
 from celery import states
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.result import AsyncResult
 from celery.utils import uuid
 from celery.utils import uuid
 
 
 from celery.tests.case import (
 from celery.tests.case import (
     AppCase,
     AppCase,
+    depends_on_current_app,
     mask_modules,
     mask_modules,
     skip_if_pypy,
     skip_if_pypy,
     skip_if_jython,
     skip_if_jython,
@@ -39,6 +39,7 @@ class test_DatabaseBackend(AppCase):
     def setup(self):
     def setup(self):
         if DatabaseBackend is None:
         if DatabaseBackend is None:
             raise SkipTest('sqlalchemy not installed')
             raise SkipTest('sqlalchemy not installed')
+        self.uri = 'sqlite:///test.db'
 
 
     def test_retry_helper(self):
     def test_retry_helper(self):
         from celery.backends.database import OperationalError
         from celery.backends.database import OperationalError
@@ -61,20 +62,16 @@ class test_DatabaseBackend(AppCase):
                 _sqlalchemy_installed()
                 _sqlalchemy_installed()
 
 
     def test_missing_dburi_raises_ImproperlyConfigured(self):
     def test_missing_dburi_raises_ImproperlyConfigured(self):
-        conf = self.app.conf
-        prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                DatabaseBackend(app=self.app)
-        finally:
-            conf.CELERY_RESULT_DBURI = prev
+        self.app.conf.CELERY_RESULT_DBURI = None
+        with self.assertRaises(ImproperlyConfigured):
+            DatabaseBackend(app=self.app)
 
 
     def test_missing_task_id_is_PENDING(self):
     def test_missing_task_id_is_PENDING(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         self.assertEqual(tb.get_status('xxx-does-not-exist'), states.PENDING)
         self.assertEqual(tb.get_status('xxx-does-not-exist'), states.PENDING)
 
 
     def test_missing_task_meta_is_dict_with_pending(self):
     def test_missing_task_meta_is_dict_with_pending(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         self.assertDictContainsSubset({
         self.assertDictContainsSubset({
             'status': states.PENDING,
             'status': states.PENDING,
             'task_id': 'xxx-does-not-exist-at-all',
             'task_id': 'xxx-does-not-exist-at-all',
@@ -83,7 +80,7 @@ class test_DatabaseBackend(AppCase):
         }, tb.get_task_meta('xxx-does-not-exist-at-all'))
         }, tb.get_task_meta('xxx-does-not-exist-at-all'))
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
 
         tid = uuid()
         tid = uuid()
 
 
@@ -95,7 +92,7 @@ class test_DatabaseBackend(AppCase):
         self.assertEqual(tb.get_result(tid), 42)
         self.assertEqual(tb.get_result(tid), 42)
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
 
         tid2 = uuid()
         tid2 = uuid()
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
@@ -106,19 +103,19 @@ class test_DatabaseBackend(AppCase):
         self.assertEqual(rindb.get('bar').data, 12345)
         self.assertEqual(rindb.get('bar').data, 12345)
 
 
     def test_mark_as_started(self):
     def test_mark_as_started(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_started(tid)
         tb.mark_as_started(tid)
         self.assertEqual(tb.get_status(tid), states.STARTED)
         self.assertEqual(tb.get_status(tid), states.STARTED)
 
 
     def test_mark_as_revoked(self):
     def test_mark_as_revoked(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_revoked(tid)
         tb.mark_as_revoked(tid)
         self.assertEqual(tb.get_status(tid), states.REVOKED)
         self.assertEqual(tb.get_status(tid), states.REVOKED)
 
 
     def test_mark_as_retry(self):
     def test_mark_as_retry(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tid = uuid()
         try:
         try:
             raise KeyError('foo')
             raise KeyError('foo')
@@ -131,7 +128,7 @@ class test_DatabaseBackend(AppCase):
             self.assertEqual(tb.get_traceback(tid), trace)
             self.assertEqual(tb.get_traceback(tid), trace)
 
 
     def test_mark_as_failure(self):
     def test_mark_as_failure(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
 
         tid3 = uuid()
         tid3 = uuid()
         try:
         try:
@@ -145,24 +142,25 @@ class test_DatabaseBackend(AppCase):
             self.assertEqual(tb.get_traceback(tid3), trace)
             self.assertEqual(tb.get_traceback(tid3), trace)
 
 
     def test_forget(self):
     def test_forget(self):
-        tb = DatabaseBackend(backend='memory://', app=self.app)
+        tb = DatabaseBackend(self.uri, backend='memory://', app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
-        x = AsyncResult(tid, backend=tb)
+        x = self.app.AsyncResult(tid, backend=tb)
         x.forget()
         x.forget()
         self.assertIsNone(x.result)
         self.assertIsNone(x.result)
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         tb.process_cleanup()
         tb.process_cleanup()
 
 
+    @depends_on_current_app
     def test_reduce(self):
     def test_reduce(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         self.assertTrue(loads(dumps(tb)))
         self.assertTrue(loads(dumps(tb)))
 
 
     def test_save__restore__delete_group(self):
     def test_save__restore__delete_group(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
 
 
         tid = uuid()
         tid = uuid()
         res = {'something': 'special'}
         res = {'something': 'special'}
@@ -177,7 +175,7 @@ class test_DatabaseBackend(AppCase):
         self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
         self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
 
 
     def test_cleanup(self):
     def test_cleanup(self):
-        tb = DatabaseBackend(app=self.app)
+        tb = DatabaseBackend(self.uri, app=self.app)
         for i in range(10):
         for i in range(10):
             tb.mark_as_done(uuid(), 42)
             tb.mark_as_done(uuid(), 42)
             tb.save_group(uuid(), {'foo': 'bar'})
             tb.save_group(uuid(), {'foo': 'bar'})

+ 6 - 8
celery/tests/backends/test_mongodb.py

@@ -7,12 +7,11 @@ from mock import MagicMock, Mock, patch, sentinel
 from nose import SkipTest
 from nose import SkipTest
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
-from celery import Celery
 from celery import states
 from celery import states
 from celery.backends import mongodb as module
 from celery.backends import mongodb as module
 from celery.backends.mongodb import MongoBackend, Bunch, pymongo
 from celery.backends.mongodb import MongoBackend, Bunch, pymongo
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 COLLECTION = 'taskmeta_celery'
 COLLECTION = 'taskmeta_celery'
 TASK_ID = str(uuid.uuid1())
 TASK_ID = str(uuid.uuid1())
@@ -58,15 +57,13 @@ class test_MongoBackend(AppCase):
             module.pymongo = prev
             module.pymongo = prev
 
 
     def test_init_no_settings(self):
     def test_init_no_settings(self):
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_MONGODB_BACKEND_SETTINGS = []
+        self.app.conf.CELERY_MONGODB_BACKEND_SETTINGS = []
         with self.assertRaises(ImproperlyConfigured):
         with self.assertRaises(ImproperlyConfigured):
-            MongoBackend(app=celery)
+            MongoBackend(app=self.app)
 
 
     def test_init_settings_is_None(self):
     def test_init_settings_is_None(self):
-        celery = Celery(set_as_current=False)
-        celery.conf.CELERY_MONGODB_BACKEND_SETTINGS = None
-        MongoBackend(app=celery)
+        self.app.conf.CELERY_MONGODB_BACKEND_SETTINGS = None
+        MongoBackend(app=self.app)
 
 
     def test_restore_group_no_entry(self):
     def test_restore_group_no_entry(self):
         x = MongoBackend(app=self.app)
         x = MongoBackend(app=self.app)
@@ -75,6 +72,7 @@ class test_MongoBackend(AppCase):
         fo.return_value = None
         fo.return_value = None
         self.assertIsNone(x._restore_group('1f3fab'))
         self.assertIsNone(x._restore_group('1f3fab'))
 
 
+    @depends_on_current_app
     def test_reduce(self):
     def test_reduce(self):
         x = MongoBackend(app=self.app)
         x = MongoBackend(app=self.app)
         self.assertTrue(loads(dumps(x)))
         self.assertTrue(loads(dumps(x)))

+ 14 - 21
celery/tests/backends/test_redis.py

@@ -8,14 +8,13 @@ from pickle import loads, dumps
 
 
 from kombu.utils import cached_property, uuid
 from kombu.utils import cached_property, uuid
 
 
+from celery import subtask
 from celery import states
 from celery import states
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.result import AsyncResult
-from celery.task import subtask
 from celery.utils.timeutils import timedelta_seconds
 from celery.utils.timeutils import timedelta_seconds
 
 
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
 class Redis(object):
 class Redis(object):
@@ -85,6 +84,7 @@ class test_RedisBackend(AppCase):
 
 
         self.MockBackend = MockBackend
         self.MockBackend = MockBackend
 
 
+    @depends_on_current_app
     def test_reduce(self):
     def test_reduce(self):
         try:
         try:
             from celery.backends.redis import RedisBackend
             from celery.backends.redis import RedisBackend
@@ -104,25 +104,18 @@ class test_RedisBackend(AppCase):
         self.assertEqual(x.db, '1')
         self.assertEqual(x.db, '1')
 
 
     def test_conf_raises_KeyError(self):
     def test_conf_raises_KeyError(self):
-        conf = AttributeDict({'CELERY_RESULT_SERIALIZER': 'json',
-                              'CELERY_MAX_CACHED_RESULTS': 1,
-                              'CELERY_ACCEPT_CONTENT': ['json'],
-                              'CELERY_TASK_RESULT_EXPIRES': None})
-        prev, self.app.conf = self.app.conf, conf
-        try:
-            self.MockBackend(app=self.app)
-        finally:
-            self.app.conf = prev
+        self.app.conf = AttributeDict({
+            'CELERY_RESULT_SERIALIZER': 'json',
+            'CELERY_MAX_CACHED_RESULTS': 1,
+            'CELERY_ACCEPT_CONTENT': ['json'],
+            'CELERY_TASK_RESULT_EXPIRES': None,
+        })
+        self.MockBackend(app=self.app)
 
 
     def test_expires_defaults_to_config(self):
     def test_expires_defaults_to_config(self):
-        conf = self.app.conf
-        prev = conf.CELERY_TASK_RESULT_EXPIRES
-        conf.CELERY_TASK_RESULT_EXPIRES = 10
-        try:
-            b = self.Backend(expires=None, app=self.app)
-            self.assertEqual(b.expires, 10)
-        finally:
-            conf.CELERY_TASK_RESULT_EXPIRES = prev
+        self.app.conf.CELERY_TASK_RESULT_EXPIRES = 10
+        b = self.Backend(expires=None, app=self.app)
+        self.assertEqual(b.expires, 10)
 
 
     def test_expires_is_int(self):
     def test_expires_is_int(self):
         b = self.Backend(expires=48, app=self.app)
         b = self.Backend(expires=48, app=self.app)
@@ -140,7 +133,7 @@ class test_RedisBackend(AppCase):
     def test_on_chord_apply(self):
     def test_on_chord_apply(self):
         self.Backend(app=self.app).on_chord_apply(
         self.Backend(app=self.app).on_chord_apply(
             'group_id', {},
             'group_id', {},
-            result=[AsyncResult(x) for x in [1, 2, 3]],
+            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
         )
         )
 
 
     def test_mget(self):
     def test_mget(self):

+ 0 - 2
celery/tests/bin/test_amqp.py

@@ -2,7 +2,6 @@ from __future__ import absolute_import
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import Celery
 from celery.bin.amqp import (
 from celery.bin.amqp import (
     AMQPAdmin,
     AMQPAdmin,
     AMQShell,
     AMQShell,
@@ -18,7 +17,6 @@ class test_AMQShell(AppCase):
 
 
     def setup(self):
     def setup(self):
         self.fh = WhateverIO()
         self.fh = WhateverIO()
-        self.app = Celery(broker='memory://', set_as_current=False)
         self.adm = self.create_adm()
         self.adm = self.create_adm()
         self.shell = AMQShell(connect=self.adm.connect, out=self.fh)
         self.shell = AMQShell(connect=self.adm.connect, out=self.fh)
 
 

+ 16 - 16
celery/tests/bin/test_base.py

@@ -10,7 +10,9 @@ from celery.bin.base import (
     Extensions,
     Extensions,
     HelpFormatter,
     HelpFormatter,
 )
 )
-from celery.tests.case import AppCase, Case, override_stdouts
+from celery.tests.case import (
+    AppCase, override_stdouts, depends_on_current_app,
+)
 
 
 
 
 class Object(object):
 class Object(object):
@@ -36,7 +38,7 @@ class MockCommand(Command):
         return args, kwargs
         return args, kwargs
 
 
 
 
-class test_Extensions(Case):
+class test_Extensions(AppCase):
 
 
     def test_load(self):
     def test_load(self):
         with patch('pkg_resources.iter_entry_points') as iterep:
         with patch('pkg_resources.iter_entry_points') as iterep:
@@ -65,7 +67,7 @@ class test_Extensions(Case):
                     e.load()
                     e.load()
 
 
 
 
-class test_HelpFormatter(Case):
+class test_HelpFormatter(AppCase):
 
 
     def test_format_epilog(self):
     def test_format_epilog(self):
         f = HelpFormatter()
         f = HelpFormatter()
@@ -276,21 +278,19 @@ class test_Command(AppCase):
         cmd.show_body = False
         cmd.show_body = False
         cmd.say_chat('->', 'foo', 'body')
         cmd.say_chat('->', 'foo', 'body')
 
 
+    @depends_on_current_app
     def test_with_cmdline_config(self):
     def test_with_cmdline_config(self):
         cmd = MockCommand()
         cmd = MockCommand()
-        try:
-            cmd.enable_config_from_cmdline = True
-            cmd.namespace = 'celeryd'
-            rest = cmd.setup_app_from_commandline(argv=[
-                '--loglevel=INFO', '--',
-                'broker.url=amqp://broker.example.com',
-                '.prefetch_multiplier=100'])
-            self.assertEqual(cmd.app.conf.BROKER_URL,
-                             'amqp://broker.example.com')
-            self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
-            self.assertListEqual(rest, ['--loglevel=INFO'])
-        finally:
-            cmd.app.conf.BROKER_URL = 'memory://'
+        cmd.enable_config_from_cmdline = True
+        cmd.namespace = 'celeryd'
+        rest = cmd.setup_app_from_commandline(argv=[
+            '--loglevel=INFO', '--',
+            'broker.url=amqp://broker.example.com',
+            '.prefetch_multiplier=100'])
+        self.assertEqual(cmd.app.conf.BROKER_URL,
+                         'amqp://broker.example.com')
+        self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
+        self.assertListEqual(rest, ['--loglevel=INFO'])
 
 
     def test_find_app(self):
     def test_find_app(self):
         cmd = MockCommand()
         cmd = MockCommand()

+ 5 - 7
celery/tests/bin/test_beat.py

@@ -70,13 +70,11 @@ class test_Beat(AppCase):
         self.assertEqual(b2.loglevel, logging.DEBUG)
         self.assertEqual(b2.loglevel, logging.DEBUG)
 
 
     def test_colorize(self):
     def test_colorize(self):
-        from celery import Celery
-        app = Celery(set_as_current=False)
-        app.log.setup = Mock()
-        b = beatapp.Beat(app=app, no_color=True)
+        self.app.log.setup = Mock()
+        b = beatapp.Beat(app=self.app, no_color=True)
         b.setup_logging()
         b.setup_logging()
-        self.assertTrue(app.log.setup.called)
-        self.assertEqual(app.log.setup.call_args[1]['colorize'], False)
+        self.assertTrue(self.app.log.setup.called)
+        self.assertEqual(self.app.log.setup.call_args[1]['colorize'], False)
 
 
     def test_init_loader(self):
     def test_init_loader(self):
         b = beatapp.Beat(app=self.app)
         b = beatapp.Beat(app=self.app)
@@ -179,7 +177,7 @@ class test_div(AppCase):
     def test_main(self):
     def test_main(self):
         sys.argv = [sys.argv[0], '-s', 'foo']
         sys.argv = [sys.argv[0], '-s', 'foo']
         try:
         try:
-            beat_bin.main()
+            beat_bin.main(app=self.app)
             self.assertTrue(MockBeat.running)
             self.assertTrue(MockBeat.running)
         finally:
         finally:
             MockBeat.running = False
             MockBeat.running = False

+ 30 - 25
celery/tests/bin/test_celery.py

@@ -7,7 +7,6 @@ from datetime import datetime
 from mock import Mock, patch
 from mock import Mock, patch
 
 
 from celery import __main__
 from celery import __main__
-from celery import task
 from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 from celery.bin.base import Error
 from celery.bin.base import Error
 from celery.bin.celery import (
 from celery.bin.celery import (
@@ -30,15 +29,10 @@ from celery.bin.celery import (
     command,
     command,
 )
 )
 
 
-from celery.tests.case import AppCase, Case, WhateverIO, override_stdouts
+from celery.tests.case import AppCase, WhateverIO, override_stdouts
 
 
 
 
-@task()
-def add(x, y):
-    return x + y
-
-
-class test__main__(Case):
+class test__main__(AppCase):
 
 
     def test_warn_deprecated(self):
     def test_warn_deprecated(self):
         with override_stdouts() as (stdout, _):
         with override_stdouts() as (stdout, _):
@@ -167,28 +161,35 @@ class test_list(AppCase):
 
 
 class test_call(AppCase):
 class test_call(AppCase):
 
 
+    def setup(self):
+
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+        self.add = add
+
     @patch('celery.app.base.Celery.send_task')
     @patch('celery.app.base.Celery.send_task')
     def test_run(self, send_task):
     def test_run(self, send_task):
         a = call(app=self.app, stderr=WhateverIO(), stdout=WhateverIO())
         a = call(app=self.app, stderr=WhateverIO(), stdout=WhateverIO())
-        a.run('tasks.add')
+        a.run(self.add.name)
         self.assertTrue(send_task.called)
         self.assertTrue(send_task.called)
 
 
-        a.run('tasks.add',
+        a.run(self.add.name,
               args=dumps([4, 4]),
               args=dumps([4, 4]),
               kwargs=dumps({'x': 2, 'y': 2}))
               kwargs=dumps({'x': 2, 'y': 2}))
         self.assertEqual(send_task.call_args[1]['args'], [4, 4])
         self.assertEqual(send_task.call_args[1]['args'], [4, 4])
         self.assertEqual(send_task.call_args[1]['kwargs'], {'x': 2, 'y': 2})
         self.assertEqual(send_task.call_args[1]['kwargs'], {'x': 2, 'y': 2})
 
 
-        a.run('tasks.add', expires=10, countdown=10)
+        a.run(self.add.name, expires=10, countdown=10)
         self.assertEqual(send_task.call_args[1]['expires'], 10)
         self.assertEqual(send_task.call_args[1]['expires'], 10)
         self.assertEqual(send_task.call_args[1]['countdown'], 10)
         self.assertEqual(send_task.call_args[1]['countdown'], 10)
 
 
         now = datetime.now()
         now = datetime.now()
         iso = now.isoformat()
         iso = now.isoformat()
-        a.run('tasks.add', expires=iso)
+        a.run(self.add.name, expires=iso)
         self.assertEqual(send_task.call_args[1]['expires'], now)
         self.assertEqual(send_task.call_args[1]['expires'], now)
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
-            a.run('tasks.add', expires='foobaribazibar')
+            a.run(self.add.name, expires='foobaribazibar')
 
 
 
 
 class test_purge(AppCase):
 class test_purge(AppCase):
@@ -208,6 +209,13 @@ class test_purge(AppCase):
 
 
 class test_result(AppCase):
 class test_result(AppCase):
 
 
+    def setup(self):
+
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+        self.add = add
+
     def test_run(self):
     def test_run(self):
         with patch('celery.result.AsyncResult.get') as get:
         with patch('celery.result.AsyncResult.get') as get:
             out = WhateverIO()
             out = WhateverIO()
@@ -217,11 +225,11 @@ class test_result(AppCase):
             self.assertIn('Jerry', out.getvalue())
             self.assertIn('Jerry', out.getvalue())
 
 
             get.return_value = 'Elaine'
             get.return_value = 'Elaine'
-            r.run('id', task=add.name)
+            r.run('id', task=self.add.name)
             self.assertIn('Elaine', out.getvalue())
             self.assertIn('Elaine', out.getvalue())
 
 
             with patch('celery.result.AsyncResult.traceback') as tb:
             with patch('celery.result.AsyncResult.traceback') as tb:
-                r.run('id', task=add.name, traceback=True)
+                r.run('id', task=self.add.name, traceback=True)
                 self.assertIn(str(tb), out.getvalue())
                 self.assertIn(str(tb), out.getvalue())
 
 
 
 
@@ -417,15 +425,12 @@ class test_inspect(AppCase):
         self.assertTrue(inspect(app=self.app).epilog)
         self.assertTrue(inspect(app=self.app).epilog)
 
 
     def test_do_call_method_sql_transport_type(self):
     def test_do_call_method_sql_transport_type(self):
-        prev, self.app.connection = self.app.connection, Mock()
-        try:
-            conn = self.app.connection.return_value = Mock(name='Connection')
-            conn.transport.driver_type = 'sql'
-            i = inspect(app=self.app)
-            with self.assertRaises(i.Error):
-                i.do_call_method(['ping'])
-        finally:
-            self.app.connection = prev
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock(name='Connection')
+        conn.transport.driver_type = 'sql'
+        i = inspect(app=self.app)
+        with self.assertRaises(i.Error):
+            i.do_call_method(['ping'])
 
 
     def test_say_directions(self):
     def test_say_directions(self):
         i = inspect(self.app)
         i = inspect(self.app)
@@ -561,7 +566,7 @@ class test_main(AppCase):
         cmd.execute_from_commandline.assert_called_with(None)
         cmd.execute_from_commandline.assert_called_with(None)
 
 
 
 
-class test_compat(Case):
+class test_compat(AppCase):
 
 
     def test_compat_command_decorator(self):
     def test_compat_command_decorator(self):
         with patch('celery.bin.celery.CeleryCommand') as CC:
         with patch('celery.bin.celery.CeleryCommand') as CC:

+ 10 - 9
celery/tests/bin/test_celeryd_detach.py

@@ -9,11 +9,11 @@ from celery.bin.celeryd_detach import (
     main,
     main,
 )
 )
 
 
-from celery.tests.case import Case, override_stdouts
+from celery.tests.case import AppCase, override_stdouts
 
 
 
 
 if not IS_WINDOWS:
 if not IS_WINDOWS:
-    class test_detached(Case):
+    class test_detached(AppCase):
 
 
         @patch('celery.bin.celeryd_detach.detached')
         @patch('celery.bin.celeryd_detach.detached')
         @patch('os.execv')
         @patch('os.execv')
@@ -32,17 +32,17 @@ if not IS_WINDOWS:
 
 
             execv.side_effect = Exception('foo')
             execv.side_effect = Exception('foo')
             r = detach('/bin/boo', ['a', 'b', 'c'],
             r = detach('/bin/boo', ['a', 'b', 'c'],
-                       logfile='/var/log', pidfile='/var/pid')
+                       logfile='/var/log', pidfile='/var/pid', app=self.app)
             context.__enter__.assert_called_with()
             context.__enter__.assert_called_with()
             self.assertTrue(logger.critical.called)
             self.assertTrue(logger.critical.called)
             setup_logs.assert_called_with('ERROR', '/var/log')
             setup_logs.assert_called_with('ERROR', '/var/log')
             self.assertEqual(r, 1)
             self.assertEqual(r, 1)
 
 
 
 
-class test_PartialOptionParser(Case):
+class test_PartialOptionParser(AppCase):
 
 
     def test_parser(self):
     def test_parser(self):
-        x = detached_celeryd()
+        x = detached_celeryd(self.app)
         p = x.Parser('celeryd_detach')
         p = x.Parser('celeryd_detach')
         options, values = p.parse_args(['--logfile=foo', '--fake', '--enable',
         options, values = p.parse_args(['--logfile=foo', '--fake', '--enable',
                                         'a', 'b', '-c1', '-d', '2'])
                                         'a', 'b', '-c1', '-d', '2'])
@@ -64,13 +64,13 @@ class test_PartialOptionParser(Case):
         p.get_option('--logfile').nargs = 1
         p.get_option('--logfile').nargs = 1
 
 
 
 
-class test_Command(Case):
+class test_Command(AppCase):
     argv = ['--autoscale=10,2', '-c', '1',
     argv = ['--autoscale=10,2', '-c', '1',
             '--logfile=/var/log', '-lDEBUG',
             '--logfile=/var/log', '-lDEBUG',
             '--', '.disable_rate_limits=1']
             '--', '.disable_rate_limits=1']
 
 
     def test_parse_options(self):
     def test_parse_options(self):
-        x = detached_celeryd()
+        x = detached_celeryd(app=self.app)
         o, v, l = x.parse_options('cd', self.argv)
         o, v, l = x.parse_options('cd', self.argv)
         self.assertEqual(o.logfile, '/var/log')
         self.assertEqual(o.logfile, '/var/log')
         self.assertEqual(l, ['--autoscale=10,2', '-c', '1',
         self.assertEqual(l, ['--autoscale=10,2', '-c', '1',
@@ -81,7 +81,7 @@ class test_Command(Case):
     @patch('sys.exit')
     @patch('sys.exit')
     @patch('celery.bin.celeryd_detach.detach')
     @patch('celery.bin.celeryd_detach.detach')
     def test_execute_from_commandline(self, detach, exit):
     def test_execute_from_commandline(self, detach, exit):
-        x = detached_celeryd()
+        x = detached_celeryd(app=self.app)
         x.execute_from_commandline(self.argv)
         x.execute_from_commandline(self.argv)
         self.assertTrue(exit.called)
         self.assertTrue(exit.called)
         detach.assert_called_with(
         detach.assert_called_with(
@@ -92,10 +92,11 @@ class test_Command(Case):
                 '--logfile=/var/log', '--pidfile=celeryd.pid',
                 '--logfile=/var/log', '--pidfile=celeryd.pid',
                 '--', '.disable_rate_limits=1'
                 '--', '.disable_rate_limits=1'
             ],
             ],
+            app=self.app,
         )
         )
 
 
     @patch('celery.bin.celeryd_detach.detached_celeryd')
     @patch('celery.bin.celeryd_detach.detached_celeryd')
     def test_main(self, command):
     def test_main(self, command):
         c = command.return_value = Mock()
         c = command.return_value = Mock()
-        main()
+        main(self.app)
         c.execute_from_commandline.assert_called_with()
         c.execute_from_commandline.assert_called_with()

+ 4 - 4
celery/tests/bin/test_celeryevdump.py

@@ -9,12 +9,12 @@ from celery.events.dumper import (
     evdump,
     evdump,
 )
 )
 
 
-from celery.tests.case import Case, WhateverIO
+from celery.tests.case import AppCase, WhateverIO
 
 
 
 
-class test_Dumper(Case):
+class test_Dumper(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         self.out = WhateverIO()
         self.out = WhateverIO()
         self.dumper = Dumper(out=self.out)
         self.dumper = Dumper(out=self.out)
 
 
@@ -44,7 +44,7 @@ class test_Dumper(Case):
     @patch('celery.events.EventReceiver.capture')
     @patch('celery.events.EventReceiver.capture')
     def test_evdump(self, capture):
     def test_evdump(self, capture):
         capture.side_effect = KeyboardInterrupt()
         capture.side_effect = KeyboardInterrupt()
-        evdump()
+        evdump(app=self.app)
 
 
     def test_evdump_error_handler(self):
     def test_evdump_error_handler(self):
         app = Mock(name='app')
         app = Mock(name='app')

+ 6 - 6
celery/tests/bin/test_multi.py

@@ -19,10 +19,10 @@ from celery.bin.multi import (
     __doc__ as doc,
     __doc__ as doc,
 )
 )
 
 
-from celery.tests.case import Case, WhateverIO
+from celery.tests.case import AppCase, WhateverIO
 
 
 
 
-class test_functions(Case):
+class test_functions(AppCase):
 
 
     def test_findsig(self):
     def test_findsig(self):
         self.assertEqual(findsig(['a', 'b', 'c', '-1']), 1)
         self.assertEqual(findsig(['a', 'b', 'c', '-1']), 1)
@@ -57,7 +57,7 @@ class test_functions(Case):
         self.assertEqual(quote("the 'quick"), "'the '\\''quick'")
         self.assertEqual(quote("the 'quick"), "'the '\\''quick'")
 
 
 
 
-class test_NamespacedOptionParser(Case):
+class test_NamespacedOptionParser(AppCase):
 
 
     def test_parse(self):
     def test_parse(self):
         x = NamespacedOptionParser(['-c:1,3', '4'])
         x = NamespacedOptionParser(['-c:1,3', '4'])
@@ -76,7 +76,7 @@ class test_NamespacedOptionParser(Case):
         self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
         self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
 
 
 
 
-class test_multi_args(Case):
+class test_multi_args(AppCase):
 
 
     @patch('socket.gethostname')
     @patch('socket.gethostname')
     def test_parse(self, gethostname):
     def test_parse(self, gethostname):
@@ -160,9 +160,9 @@ class test_multi_args(Case):
         )
         )
 
 
 
 
-class test_MultiTool(Case):
+class test_MultiTool(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         self.fh = WhateverIO()
         self.fh = WhateverIO()
         self.env = {}
         self.env = {}
         self.t = MultiTool(env=self.env, fh=self.fh)
         self.t = MultiTool(env=self.env, fh=self.fh)

+ 53 - 74
celery/tests/bin/test_worker.py

@@ -12,7 +12,6 @@ from nose import SkipTest
 from billiard import current_process
 from billiard import current_process
 from kombu import Exchange, Queue
 from kombu import Exchange, Queue
 
 
-from celery import Celery
 from celery import platforms
 from celery import platforms
 from celery import signals
 from celery import signals
 from celery.app import trace
 from celery.app import trace
@@ -68,33 +67,27 @@ class Worker(cd.Worker):
 class test_Worker(WorkerAppCase):
 class test_Worker(WorkerAppCase):
     Worker = Worker
     Worker = Worker
 
 
-    def teardown(self):
-        self.app.conf.CELERY_INCLUDE = ()
-
     @disable_stdouts
     @disable_stdouts
     def test_queues_string(self):
     def test_queues_string(self):
-        celery = Celery(set_as_current=False)
-        w = celery.Worker()
+        w = self.app.Worker()
         w.setup_queues('foo,bar,baz')
         w.setup_queues('foo,bar,baz')
         self.assertEqual(w.queues, ['foo', 'bar', 'baz'])
         self.assertEqual(w.queues, ['foo', 'bar', 'baz'])
-        self.assertTrue('foo' in celery.amqp.queues)
+        self.assertTrue('foo' in self.app.amqp.queues)
 
 
     @disable_stdouts
     @disable_stdouts
     def test_cpu_count(self):
     def test_cpu_count(self):
-        celery = Celery(set_as_current=False)
         with patch('celery.worker.cpu_count') as cpu_count:
         with patch('celery.worker.cpu_count') as cpu_count:
             cpu_count.side_effect = NotImplementedError()
             cpu_count.side_effect = NotImplementedError()
-            w = celery.Worker(concurrency=None)
+            w = self.app.Worker(concurrency=None)
             self.assertEqual(w.concurrency, 2)
             self.assertEqual(w.concurrency, 2)
-        w = celery.Worker(concurrency=5)
+        w = self.app.Worker(concurrency=5)
         self.assertEqual(w.concurrency, 5)
         self.assertEqual(w.concurrency, 5)
 
 
     @disable_stdouts
     @disable_stdouts
     def test_windows_B_option(self):
     def test_windows_B_option(self):
-        celery = Celery(set_as_current=False)
-        celery.IS_WINDOWS = True
+        self.app.IS_WINDOWS = True
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
-            worker(app=celery).run(beat=True)
+            worker(app=self.app).run(beat=True)
 
 
     def test_setup_concurrency_very_early(self):
     def test_setup_concurrency_very_early(self):
         x = worker()
         x = worker()
@@ -124,26 +117,23 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_invalid_loglevel_gives_error(self):
     def test_invalid_loglevel_gives_error(self):
-        x = worker(app=Celery(set_as_current=False))
+        x = worker(app=self.app)
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             x.run(loglevel='GRIM_REAPER')
             x.run(loglevel='GRIM_REAPER')
 
 
     def test_no_loglevel(self):
     def test_no_loglevel(self):
-        app = Celery(set_as_current=False)
-        app.Worker = Mock()
-        worker(app=app).run(loglevel=None)
+        self.app.Worker = Mock()
+        worker(app=self.app).run(loglevel=None)
 
 
     def test_tasklist(self):
     def test_tasklist(self):
-        celery = Celery(set_as_current=False)
-        worker = celery.Worker()
+        worker = self.app.Worker()
         self.assertTrue(worker.app.tasks)
         self.assertTrue(worker.app.tasks)
         self.assertTrue(worker.app.finalized)
         self.assertTrue(worker.app.finalized)
         self.assertTrue(worker.tasklist(include_builtins=True))
         self.assertTrue(worker.tasklist(include_builtins=True))
         worker.tasklist(include_builtins=False)
         worker.tasklist(include_builtins=False)
 
 
     def test_extra_info(self):
     def test_extra_info(self):
-        celery = Celery(set_as_current=False)
-        worker = celery.Worker()
+        worker = self.app.Worker()
         worker.loglevel = logging.WARNING
         worker.loglevel = logging.WARNING
         self.assertFalse(worker.extra_info())
         self.assertFalse(worker.extra_info())
         worker.loglevel = logging.INFO
         worker.loglevel = logging.INFO
@@ -154,6 +144,7 @@ class test_Worker(WorkerAppCase):
         worker = self.Worker(app=self.app, loglevel='INFO')
         worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
         self.assertEqual(worker.loglevel, logging.INFO)
 
 
+    @disable_stdouts
     def test_run_worker(self):
     def test_run_worker(self):
         handlers = {}
         handlers = {}
 
 
@@ -193,29 +184,21 @@ class test_Worker(WorkerAppCase):
         worker.autoscale = 13, 10
         worker.autoscale = 13, 10
         self.assertTrue(worker.startup_info())
         self.assertTrue(worker.startup_info())
 
 
-        app = Celery(set_as_current=False)
-        worker = self.Worker(app=app, queues='foo,bar,baz,xuzzy,do,re,mi')
-        prev, app.loader = app.loader, Mock()
-        try:
-            app.loader.__module__ = 'acme.baked_beans'
-            self.assertTrue(worker.startup_info())
-        finally:
-            app.loader = prev
+        prev_loader = self.app.loader
+        worker = self.Worker(app=self.app, queues='foo,bar,baz,xuzzy,do,re,mi')
+        self.app.loader = Mock()
+        self.app.loader.__module__ = 'acme.baked_beans'
+        self.assertTrue(worker.startup_info())
 
 
-        prev, app.loader = app.loader, Mock()
-        try:
-            app.loader.__module__ = 'celery.loaders.foo'
-            self.assertTrue(worker.startup_info())
-        finally:
-            app.loader = prev
+        self.app.loader = Mock()
+        self.app.loader.__module__ = 'celery.loaders.foo'
+        self.assertTrue(worker.startup_info())
 
 
         from celery.loaders.app import AppLoader
         from celery.loaders.app import AppLoader
-        prev, app.loader = app.loader, AppLoader(app=self.app)
-        try:
-            self.assertTrue(worker.startup_info())
-        finally:
-            app.loader = prev
+        self.app.loader = AppLoader(app=self.app)
+        self.assertTrue(worker.startup_info())
 
 
+        self.app.loader = prev_loader
         worker.send_events = True
         worker.send_events = True
         self.assertTrue(worker.startup_info())
         self.assertTrue(worker.startup_info())
 
 
@@ -239,32 +222,32 @@ class test_Worker(WorkerAppCase):
     def test_init_queues(self):
     def test_init_queues(self):
         app = self.app
         app = self.app
         c = app.conf
         c = app.conf
-        p, app.amqp.queues = app.amqp.queues, app.amqp.Queues({
+        app.amqp.queues = app.amqp.Queues({
             'celery': {'exchange': 'celery',
             'celery': {'exchange': 'celery',
                        'routing_key': 'celery'},
                        'routing_key': 'celery'},
             'video': {'exchange': 'video',
             'video': {'exchange': 'video',
-                      'routing_key': 'video'}})
-        try:
-            worker = self.Worker(app=self.app)
-            worker.setup_queues(['video'])
-            self.assertIn('video', app.amqp.queues)
-            self.assertIn('video', app.amqp.queues.consume_from)
-            self.assertIn('celery', app.amqp.queues)
-            self.assertNotIn('celery', app.amqp.queues.consume_from)
-
-            c.CELERY_CREATE_MISSING_QUEUES = False
-            del(app.amqp.queues)
-            with self.assertRaises(ImproperlyConfigured):
-                self.Worker(app=self.app).setup_queues(['image'])
-            del(app.amqp.queues)
-            c.CELERY_CREATE_MISSING_QUEUES = True
-            worker = self.Worker(app=self.app)
-            worker.setup_queues(queues=['image'])
-            self.assertIn('image', app.amqp.queues.consume_from)
-            self.assertEqual(Queue('image', Exchange('image'),
-                             routing_key='image'), app.amqp.queues['image'])
-        finally:
-            app.amqp.queues = p
+                      'routing_key': 'video'},
+        })
+        worker = self.Worker(app=self.app)
+        worker.setup_queues(['video'])
+        self.assertIn('video', app.amqp.queues)
+        self.assertIn('video', app.amqp.queues.consume_from)
+        self.assertIn('celery', app.amqp.queues)
+        self.assertNotIn('celery', app.amqp.queues.consume_from)
+
+        c.CELERY_CREATE_MISSING_QUEUES = False
+        del(app.amqp.queues)
+        with self.assertRaises(ImproperlyConfigured):
+            self.Worker(app=self.app).setup_queues(['image'])
+        del(app.amqp.queues)
+        c.CELERY_CREATE_MISSING_QUEUES = True
+        worker = self.Worker(app=self.app)
+        worker.setup_queues(queues=['image'])
+        self.assertIn('image', app.amqp.queues.consume_from)
+        self.assertEqual(
+            Queue('image', Exchange('image'), routing_key='image'),
+            app.amqp.queues['image'],
+        )
 
 
     @disable_stdouts
     @disable_stdouts
     def test_autoscale_argument(self):
     def test_autoscale_argument(self):
@@ -272,6 +255,7 @@ class test_Worker(WorkerAppCase):
         self.assertListEqual(worker1.autoscale, [10, 3])
         self.assertListEqual(worker1.autoscale, [10, 3])
         worker2 = self.Worker(app=self.app, autoscale='10')
         worker2 = self.Worker(app=self.app, autoscale='10')
         self.assertListEqual(worker2.autoscale, [10, 0])
         self.assertListEqual(worker2.autoscale, [10, 0])
+        self.assert_no_logging_side_effect()
 
 
     def test_include_argument(self):
     def test_include_argument(self):
         worker1 = self.Worker(app=self.app, include='some.module')
         worker1 = self.Worker(app=self.app, include='some.module')
@@ -318,16 +302,11 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_on_start_custom_logging(self):
     def test_on_start_custom_logging(self):
-        prev, self.app.log.redirect_stdouts = (
-            self.app.log.redirect_stdouts, Mock(),
-        )
-        try:
-            worker = self.Worker(app=self.app, redirect_stoutds=True)
-            worker._custom_logging = True
-            worker.on_start()
-            self.assertFalse(self.app.log.redirect_stdouts.called)
-        finally:
-            self.app.log.redirect_stdouts = prev
+        self.app.log.redirect_stdouts = Mock()
+        worker = self.Worker(app=self.app, redirect_stoutds=True)
+        worker._custom_logging = True
+        worker.on_start()
+        self.assertFalse(self.app.log.redirect_stdouts.called)
 
 
     def test_setup_logging_no_color(self):
     def test_setup_logging_no_color(self):
         worker = self.Worker(
         worker = self.Worker(
@@ -463,7 +442,7 @@ class test_funs(WorkerAppCase):
         p, cd.Worker = cd.Worker, Worker
         p, cd.Worker = cd.Worker, Worker
         s, sys.argv = sys.argv, ['worker', '--discard']
         s, sys.argv = sys.argv, ['worker', '--discard']
         try:
         try:
-            worker_main()
+            worker_main(app=self.app)
         finally:
         finally:
             cd.Worker = p
             cd.Worker = p
             sys.argv = s
             sys.argv = s

+ 128 - 47
celery/tests/case.py

@@ -9,6 +9,7 @@ except AttributeError:
     from unittest2.util import safe_repr, unorderable_list_difference  # noqa
     from unittest2.util import safe_repr, unorderable_list_difference  # noqa
 
 
 import importlib
 import importlib
+import inspect
 import logging
 import logging
 import os
 import os
 import platform
 import platform
@@ -18,6 +19,7 @@ import time
 import warnings
 import warnings
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
+from copy import deepcopy
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from functools import partial, wraps
 from functools import partial, wraps
 from types import ModuleType
 from types import ModuleType
@@ -27,28 +29,93 @@ try:
 except ImportError:
 except ImportError:
     import mock  # noqa
     import mock  # noqa
 from nose import SkipTest
 from nose import SkipTest
+from kombu import Queue
 from kombu.log import NullHandler
 from kombu.log import NullHandler
-from kombu.utils import nested
+from kombu.utils import nested, symbol_by_name
 
 
+from celery import Celery
+from celery.app import current_app
+from celery.backends.cache import CacheBackend, DummyClient
 from celery.five import (
 from celery.five import (
     WhateverIO, builtins, items, reraise,
     WhateverIO, builtins, items, reraise,
     string_t, values, open_fqdn,
     string_t, values, open_fqdn,
 )
 )
 from celery.utils.functional import noop
 from celery.utils.functional import noop
+from celery.utils.imports import qualname
 
 
 __all__ = [
 __all__ = [
     'Case', 'AppCase', 'Mock', 'patch', 'call', 'skip_unless_module',
     'Case', 'AppCase', 'Mock', 'patch', 'call', 'skip_unless_module',
-    'wrap_logger', 'eager_tasks', 'with_environ', 'sleepdeprived',
+    'wrap_logger', 'with_environ', 'sleepdeprived',
     'skip_if_environ', 'skip_if_quick', 'todo', 'skip', 'skip_if',
     'skip_if_environ', 'skip_if_quick', 'todo', 'skip', 'skip_if',
     'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
     'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
     'replace_module_value', 'sys_platform', 'reset_modules',
     'replace_module_value', 'sys_platform', 'reset_modules',
     'patch_modules', 'mock_context', 'mock_open', 'patch_many',
     'patch_modules', 'mock_context', 'mock_open', 'patch_many',
-    'patch_settings', 'assert_signal_called', 'skip_if_pypy',
+    'assert_signal_called', 'skip_if_pypy',
     'skip_if_jython', 'body_from_sig', 'restore_logging',
     'skip_if_jython', 'body_from_sig', 'restore_logging',
 ]
 ]
 patch = mock.patch
 patch = mock.patch
 call = mock.call
 call = mock.call
 
 
+CASE_REDEFINES_SETUP = """\
+{name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
+"""
+CASE_REDEFINES_TEARDOWN = """\
+{name} (subclass of AppCase) redefines private "tearDown", \
+should be: "teardown"\
+"""
+
+CELERY_TEST_CONFIG = {
+    #: Don't want log output when running suite.
+    'CELERYD_HIJACK_ROOT_LOGGER': False,
+    'CELERY_SEND_TASK_ERROR_EMAILS': False,
+    'CELERY_DEFAULT_QUEUE': 'testcelery',
+    'CELERY_DEFAULT_EXCHANGE': 'testcelery',
+    'CELERY_DEFAULT_ROUTING_KEY': 'testcelery',
+    'CELERY_QUEUES': (
+        Queue('testcelery', routing_key='testcelery'),
+    ),
+    'CELERY_ENABLE_UTC': True,
+    'CELERY_TIMEZONE': 'UTC',
+    'CELERYD_LOG_COLOR': False,
+
+    # Mongo results tests (only executed if installed and running)
+    'CELERY_MONGODB_BACKEND_SETTINGS': {
+        'host': os.environ.get('MONGO_HOST') or 'localhost',
+        'port': os.environ.get('MONGO_PORT') or 27017,
+        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
+        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION')
+                                or 'taskmeta_collection'),
+        'user': os.environ.get('MONGO_USER'),
+        'password': os.environ.get('MONGO_PASSWORD'),
+    }
+}
+
+
+class Trap(object):
+
+    def __getattr__(self, name):
+        raise RuntimeError('Test depends on current_app')
+
+
+class UnitLogging(symbol_by_name(Celery.log_cls)):
+
+    def __init__(self, *args, **kwargs):
+        super(UnitLogging, self).__init__(*args, **kwargs)
+        self.already_setup = True
+
+
+def UnitApp(name=None, broker=None, backend=None,
+            set_as_current=False, log=UnitLogging, **kwargs):
+
+    app = Celery(name or 'celery.tests',
+                 broker=broker or 'memory://',
+                 backend=backend or 'cache+memory://',
+                 set_as_current=set_as_current,
+                 log=log,
+                 **kwargs)
+    app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
+    return app
+
 
 
 class Mock(mock.Mock):
 class Mock(mock.Mock):
 
 
@@ -204,29 +271,74 @@ class Case(unittest.TestCase):
             self.fail(self._formatMessage(msg, standardMsg))
             self.fail(self._formatMessage(msg, standardMsg))
 
 
 
 
+def depends_on_current_app(fun):
+    if inspect.isclass(fun):
+        fun.contained = False
+    else:
+        @wraps(fun)
+        def __inner(self, *args, **kwargs):
+            self.app.set_current()
+            return fun(self, *args, **kwargs)
+        return __inner
+
+
 class AppCase(Case):
 class AppCase(Case):
     contained = True
     contained = True
 
 
+    def __new__(cls, *args, **kwargs):
+        if cls.__dict__.get('setUp'):
+            raise RuntimeError(CASE_REDEFINES_SETUP.format(name=qualname(cls)))
+        if cls.__dict__.get('tearDown'):
+            raise RuntimeError(CASE_REDEFINES_TEARDOWN.format(
+                name=qualname(cls)),
+            )
+        return super(AppCase, cls).__new__(cls, *args, **kwargs)
+
+    def Celery(self, *args, **kwargs):
+        return UnitApp(*args, **kwargs)
+
     def setUp(self):
     def setUp(self):
-        from celery import Celery
-        from celery.app import current_app
-        from celery.backends.cache import CacheBackend, DummyClient
+        from celery import _state
         self._current_app = current_app()
         self._current_app = current_app()
-        app = self.app = (Celery(set_as_current=False)
-                          if self.contained else self._current_app)
-        if isinstance(app.backend, CacheBackend):
-            if isinstance(app.backend.client, DummyClient):
-                app.backend.client.cache.clear()
-        app.backend._cache.clear()
+        self._default_app = _state.default_app
+        trap = Trap()
+        _state.set_default_app(trap)
+        _state._tls.current_app = trap
+
+        self.app = self.Celery(set_as_current=False)
+        if not self.contained:
+            self.app.set_current()
         root = logging.getLogger()
         root = logging.getLogger()
         self.__rootlevel = root.level
         self.__rootlevel = root.level
         self.__roothandlers = root.handlers
         self.__roothandlers = root.handlers
-        self.setup()
+        try:
+            self.setup()
+        except:
+            self._teardown_app()
+            raise
+
+    def _teardown_app(self):
+        backend = self.app.__dict__.get('backend')
+        if backend is not None:
+            if isinstance(backend, CacheBackend):
+                if isinstance(backend.client, DummyClient):
+                    backend.client.cache.clear()
+                backend._cache.clear()
+        from celery._state import _tls, set_default_app
+        set_default_app(self._default_app)
+        _tls.current_app = self._current_app
+        if self.app is not self._current_app:
+            self.app.close()
+        self.app = None
 
 
     def tearDown(self):
     def tearDown(self):
-        self.teardown()
-        self._current_app.set_current()
+        try:
+            self.teardown()
+        finally:
+            self._teardown_app()
+        self.assert_no_logging_side_effect()
 
 
+    def assert_no_logging_side_effect(self):
         root = logging.getLogger()
         root = logging.getLogger()
         this = '.'.join([self.__class__.__name__, self._testMethodName])
         this = '.'.join([self.__class__.__name__, self._testMethodName])
         if root.level != self.__rootlevel:
         if root.level != self.__rootlevel:
@@ -258,16 +370,6 @@ def wrap_logger(logger, loglevel=logging.ERROR):
         logger.handlers = old_handlers
         logger.handlers = old_handlers
 
 
 
 
-@contextmanager
-def eager_tasks(app):
-    prev = app.conf.CELERY_ALWAYS_EAGER
-    app.conf.CELERY_ALWAYS_EAGER = True
-    try:
-        yield True
-    finally:
-        app.conf.CELERY_ALWAYS_EAGER = prev
-
-
 def with_environ(env_name, env_value):
 def with_environ(env_name, env_value):
 
 
     def _envpatched(fun):
     def _envpatched(fun):
@@ -279,8 +381,7 @@ def with_environ(env_name, env_value):
             try:
             try:
                 return fun(*args, **kwargs)
                 return fun(*args, **kwargs)
             finally:
             finally:
-                if prev_val is not None:
-                    os.environ[env_name] = prev_val
+                os.environ[env_name] = prev_val or ''
 
 
         return _patch_environ
         return _patch_environ
     return _envpatched
     return _envpatched
@@ -554,26 +655,6 @@ def patch_many(*targets):
     return nested(*[patch(target) for target in targets])
     return nested(*[patch(target) for target in targets])
 
 
 
 
-@contextmanager
-def patch_settings(app, **config):
-    if app is None:
-        from celery import current_app
-        app = current_app
-    prev = {}
-    for key, value in items(config):
-        try:
-            prev[key] = getattr(app.conf, key)
-        except AttributeError:
-            pass
-        setattr(app.conf, key, value)
-
-    try:
-        yield app.conf
-    finally:
-        for key, value in items(prev):
-            setattr(app.conf, key, value)
-
-
 @contextmanager
 @contextmanager
 def assert_signal_called(signal, **expected):
 def assert_signal_called(signal, **expected):
     handler = Mock()
     handler = Mock()

+ 63 - 10
celery/tests/compat_modules/test_compat.py

@@ -1,6 +1,15 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from celery.tests.case import AppCase
+from datetime import timedelta
+
+from celery.schedules import schedule
+from celery.task import (
+    periodic_task,
+    PeriodicTask
+)
+from celery.utils.timeutils import timedelta_seconds
+
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
 class test_Task(AppCase):
 class test_Task(AppCase):
@@ -11,16 +20,60 @@ class test_Task(AppCase):
         class timkX(OldTask):
         class timkX(OldTask):
             abstract = True
             abstract = True
 
 
-        app = Celery(set_as_current=False, accept_magic_kwargs=True)
-        timkX.bind(app)
-        # see #918
-        self.assertFalse(timkX.accept_magic_kwargs)
+        with self.Celery(set_as_current=False,
+                         accept_magic_kwargs=True) as app:
+            timkX.bind(app)
+            # see #918
+            self.assertFalse(timkX.accept_magic_kwargs)
 
 
-        from celery import Task as NewTask
+            from celery import Task as NewTask
 
 
-        class timkY(NewTask):
-            abstract = True
+            class timkY(NewTask):
+                abstract = True
+
+            timkY.bind(app)
+            self.assertFalse(timkY.accept_magic_kwargs)
+
+
+@depends_on_current_app
+class test_periodic_tasks(AppCase):
+
+    def setup(self):
+        @periodic_task(app=self.app, shared=False,
+                       run_every=schedule(timedelta(hours=1), app=self.app))
+        def my_periodic():
+            pass
+        self.my_periodic = my_periodic
+
+    def now(self):
+        return self.app.now()
+
+    def test_must_have_run_every(self):
+        with self.assertRaises(NotImplementedError):
+            type('Foo', (PeriodicTask, ), {'__module__': __name__})
+
+    def test_remaining_estimate(self):
+        s = self.my_periodic.run_every
+        self.assertIsInstance(
+            s.remaining_estimate(s.maybe_make_aware(self.now())),
+            timedelta)
+
+    def test_is_due_not_due(self):
+        due, remaining = self.my_periodic.run_every.is_due(self.now())
+        self.assertFalse(due)
+        # This assertion may fail if executed in the
+        # first minute of an hour, thus 59 instead of 60
+        self.assertGreater(remaining, 59)
 
 
-        timkY.bind(app)
-        self.assertFalse(timkY.accept_magic_kwargs)
+    def test_is_due(self):
+        p = self.my_periodic
+        due, remaining = p.run_every.is_due(
+            self.now() - p.run_every.run_every,
+        )
+        self.assertTrue(due)
+        self.assertEqual(remaining,
+                         timedelta_seconds(p.run_every.run_every))
 
 
+    def test_schedule_repr(self):
+        p = self.my_periodic
+        self.assertTrue(repr(p.run_every))

+ 6 - 7
celery/tests/compat_modules/test_compat_utils.py

@@ -1,14 +1,15 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-
 import celery
 import celery
+
 from celery.app.task import Task as ModernTask
 from celery.app.task import Task as ModernTask
 from celery.task.base import Task as CompatTask
 from celery.task.base import Task as CompatTask
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
-class test_MagicModule(Case):
+@depends_on_current_app
+class test_MagicModule(AppCase):
 
 
     def test_class_property_set_without_type(self):
     def test_class_property_set_without_type(self):
         self.assertTrue(ModernTask.__dict__['app'].__get__(CompatTask()))
         self.assertTrue(ModernTask.__dict__['app'].__get__(CompatTask()))
@@ -21,10 +22,8 @@ class test_MagicModule(Case):
 
 
         class X(CompatTask):
         class X(CompatTask):
             pass
             pass
-
-        app = celery.Celery(set_as_current=False)
-        ModernTask.__dict__['app'].__set__(X(), app)
-        self.assertEqual(X.app, app)
+        ModernTask.__dict__['app'].__set__(X(), self.app)
+        self.assertIs(X.app, self.app)
 
 
     def test_dir(self):
     def test_dir(self):
         self.assertTrue(dir(celery.messaging))
         self.assertTrue(dir(celery.messaging))

+ 4 - 3
celery/tests/compat_modules/test_decorators.py

@@ -4,21 +4,22 @@ import warnings
 
 
 from celery.task import base
 from celery.task import base
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
 def add(x, y):
 def add(x, y):
     return x + y
     return x + y
 
 
 
 
-class test_decorators(Case):
+@depends_on_current_app
+class test_decorators(AppCase):
 
 
     def test_task_alias(self):
     def test_task_alias(self):
         from celery import task
         from celery import task
         self.assertTrue(task.__file__)
         self.assertTrue(task.__file__)
         self.assertTrue(task(add))
         self.assertTrue(task(add))
 
 
-    def setUp(self):
+    def setup(self):
         with warnings.catch_warnings(record=True):
         with warnings.catch_warnings(record=True):
             from celery import decorators
             from celery import decorators
             self.decorators = decorators
             self.decorators = decorators

+ 13 - 10
celery/tests/compat_modules/test_http.py

@@ -13,7 +13,7 @@ from kombu.utils.encoding import from_utf8
 
 
 from celery.five import StringIO, items
 from celery.five import StringIO, items
 from celery.task import http
 from celery.task import http
-from celery.tests.case import AppCase, Case, eager_tasks
+from celery.tests.case import AppCase, Case
 
 
 
 
 @contextmanager
 @contextmanager
@@ -140,16 +140,19 @@ class test_HttpDispatch(AppCase):
 
 
 
 
 class test_URL(AppCase):
 class test_URL(AppCase):
-    contained = False
 
 
     def test_URL_get_async(self):
     def test_URL_get_async(self):
-        with eager_tasks(self.app):
-            with mock_urlopen(success_response(100)):
-                d = http.URL('http://example.com/mul').get_async(x=10, y=10)
-                self.assertEqual(d.get(), 100)
+        self.app.conf.CELERY_ALWAYS_EAGER = True
+        with mock_urlopen(success_response(100)):
+            d = http.URL(
+                'http://example.com/mul', app=self.app,
+            ).get_async(x=10, y=10)
+            self.assertEqual(d.get(), 100)
 
 
     def test_URL_post_async(self):
     def test_URL_post_async(self):
-        with eager_tasks(self.app):
-            with mock_urlopen(success_response(100)):
-                d = http.URL('http://example.com/mul').post_async(x=10, y=10)
-                self.assertEqual(d.get(), 100)
+        self.app.conf.CELERY_ALWAYS_EAGER = True
+        with mock_urlopen(success_response(100)):
+            d = http.URL(
+                'http://example.com/mul', app=self.app,
+            ).post_async(x=10, y=10)
+            self.assertEqual(d.get(), 100)

+ 3 - 2
celery/tests/compat_modules/test_messaging.py

@@ -1,10 +1,11 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from celery import messaging
 from celery import messaging
-from celery.tests.case import Case
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
-class test_compat_messaging_module(Case):
+@depends_on_current_app
+class test_compat_messaging_module(AppCase):
 
 
     def test_get_consume_set(self):
     def test_get_consume_set(self):
         conn = messaging.establish_connection()
         conn = messaging.establish_connection()

+ 88 - 42
celery/tests/compat_modules/test_sets.py

@@ -1,44 +1,92 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 import anyjson
 import anyjson
+import warnings
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
+from celery import uuid
+from celery.result import TaskSetResult
 from celery.task import Task
 from celery.task import Task
-from celery.task.sets import subtask, TaskSet
 from celery.canvas import Signature
 from celery.canvas import Signature
 
 
+from celery.tests.tasks.test_result import make_mock_group
 from celery.tests.case import AppCase
 from celery.tests.case import AppCase
 
 
 
 
-class MockTask(Task):
-    name = 'tasks.add'
+class SetsCase(AppCase):
 
 
-    def run(self, x, y, **kwargs):
-        return x + y
+    def setup(self):
+        with warnings.catch_warnings(record=True):
+            from celery.task import sets
+            self.sets = sets
+            self.subtask = sets.subtask
+            self.TaskSet = sets.TaskSet
 
 
-    @classmethod
-    def apply_async(cls, args, kwargs, **options):
-        return (args, kwargs, options)
+        class MockTask(Task):
+            app = self.app
+            name = 'tasks.add'
 
 
-    @classmethod
-    def apply(cls, args, kwargs, **options):
-        return (args, kwargs, options)
+            def run(self, x, y, **kwargs):
+                return x + y
 
 
+            @classmethod
+            def apply_async(cls, args, kwargs, **options):
+                return (args, kwargs, options)
 
 
-class test_subtask(AppCase):
+            @classmethod
+            def apply(cls, args, kwargs, **options):
+                return (args, kwargs, options)
+        self.MockTask = MockTask
+
+
+class test_TaskSetResult(AppCase):
+
+    def setup(self):
+        self.size = 10
+        self.ts = TaskSetResult(uuid(), make_mock_group(self.app, self.size))
+
+    def test_total(self):
+        self.assertEqual(self.ts.total, self.size)
+
+    def test_compat_properties(self):
+        self.assertEqual(self.ts.taskset_id, self.ts.id)
+        self.ts.taskset_id = 'foo'
+        self.assertEqual(self.ts.taskset_id, 'foo')
+
+    def test_compat_subtasks_kwarg(self):
+        x = TaskSetResult(uuid(), subtasks=[1, 2, 3])
+        self.assertEqual(x.results, [1, 2, 3])
+
+    def test_itersubtasks(self):
+        it = self.ts.itersubtasks()
+
+        for i, t in enumerate(it):
+            self.assertEqual(t.get(), i)
+
+
+class test_App(AppCase):
+
+    def test_TaskSet(self):
+        with warnings.catch_warnings(record=True):
+            ts = self.app.TaskSet()
+            self.assertListEqual(ts.tasks, [])
+            self.assertIs(ts.app, self.app)
+
+
+class test_subtask(SetsCase):
 
 
     def test_behaves_like_type(self):
     def test_behaves_like_type(self):
-        s = subtask('tasks.add', (2, 2), {'cache': True},
-                    {'routing_key': 'CPU-bound'})
-        self.assertDictEqual(subtask(s), s)
+        s = self.subtask('tasks.add', (2, 2), {'cache': True},
+                         {'routing_key': 'CPU-bound'})
+        self.assertDictEqual(self.subtask(s), s)
 
 
     def test_task_argument_can_be_task_cls(self):
     def test_task_argument_can_be_task_cls(self):
-        s = subtask(MockTask, (2, 2))
-        self.assertEqual(s.task, MockTask.name)
+        s = self.subtask(self.MockTask, (2, 2))
+        self.assertEqual(s.task, self.MockTask.name)
 
 
     def test_apply_async(self):
     def test_apply_async(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, 2), {'cache': True}, {'routing_key': 'CPU-bound'},
             (2, 2), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         )
         args, kwargs, options = s.apply_async()
         args, kwargs, options = s.apply_async()
@@ -47,7 +95,7 @@ class test_subtask(AppCase):
         self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
         self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
 
 
     def test_delay_argmerge(self):
     def test_delay_argmerge(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         )
         args, kwargs, options = s.delay(10, cache=False, other='foo')
         args, kwargs, options = s.delay(10, cache=False, other='foo')
@@ -56,7 +104,7 @@ class test_subtask(AppCase):
         self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
         self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
 
 
     def test_apply_async_argmerge(self):
     def test_apply_async_argmerge(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         )
         args, kwargs, options = s.apply_async((10, ),
         args, kwargs, options = s.apply_async((10, ),
@@ -70,7 +118,7 @@ class test_subtask(AppCase):
                                        'exchange': 'fast'})
                                        'exchange': 'fast'})
 
 
     def test_apply_argmerge(self):
     def test_apply_argmerge(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         )
         args, kwargs, options = s.apply((10, ),
         args, kwargs, options = s.apply((10, ),
@@ -85,50 +133,48 @@ class test_subtask(AppCase):
         )
         )
 
 
     def test_is_JSON_serializable(self):
     def test_is_JSON_serializable(self):
-        s = MockTask.subtask(
+        s = self.MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
         )
         )
         s.args = list(s.args)                   # tuples are not preserved
         s.args = list(s.args)                   # tuples are not preserved
                                                 # but this doesn't matter.
                                                 # but this doesn't matter.
-        self.assertEqual(s, subtask(anyjson.loads(anyjson.dumps(s))))
+        self.assertEqual(s, self.subtask(anyjson.loads(anyjson.dumps(s))))
 
 
     def test_repr(self):
     def test_repr(self):
-        s = MockTask.subtask((2, ), {'cache': True})
+        s = self.MockTask.subtask((2, ), {'cache': True})
         self.assertIn('2', repr(s))
         self.assertIn('2', repr(s))
         self.assertIn('cache=True', repr(s))
         self.assertIn('cache=True', repr(s))
 
 
     def test_reduce(self):
     def test_reduce(self):
-        s = MockTask.subtask((2, ), {'cache': True})
+        s = self.MockTask.subtask((2, ), {'cache': True})
         cls, args = s.__reduce__()
         cls, args = s.__reduce__()
         self.assertDictEqual(dict(cls(*args)), dict(s))
         self.assertDictEqual(dict(cls(*args)), dict(s))
 
 
 
 
-class test_TaskSet(AppCase):
+class test_TaskSet(SetsCase):
 
 
     def test_task_arg_can_be_iterable__compat(self):
     def test_task_arg_can_be_iterable__compat(self):
-        ts = TaskSet([MockTask.subtask((i, i))
-                      for i in (2, 4, 8)], app=self.app)
+        ts = self.TaskSet([self.MockTask.subtask((i, i))
+                           for i in (2, 4, 8)], app=self.app)
         self.assertEqual(len(ts), 3)
         self.assertEqual(len(ts), 3)
 
 
     def test_respects_ALWAYS_EAGER(self):
     def test_respects_ALWAYS_EAGER(self):
         app = self.app
         app = self.app
 
 
-        class MockTaskSet(TaskSet):
+        class MockTaskSet(self.TaskSet):
             applied = 0
             applied = 0
 
 
             def apply(self, *args, **kwargs):
             def apply(self, *args, **kwargs):
                 self.applied += 1
                 self.applied += 1
 
 
         ts = MockTaskSet(
         ts = MockTaskSet(
-            [MockTask.subtask((i, i)) for i in (2, 4, 8)],
+            [self.MockTask.subtask((i, i)) for i in (2, 4, 8)],
             app=self.app,
             app=self.app,
         )
         )
         app.conf.CELERY_ALWAYS_EAGER = True
         app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            ts.apply_async()
-        finally:
-            app.conf.CELERY_ALWAYS_EAGER = False
+        ts.apply_async()
         self.assertEqual(ts.applied, 1)
         self.assertEqual(ts.applied, 1)
+        app.conf.CELERY_ALWAYS_EAGER = False
 
 
         with patch('celery.task.sets.get_current_worker_task') as gwt:
         with patch('celery.task.sets.get_current_worker_task') as gwt:
             parent = gwt.return_value = Mock()
             parent = gwt.return_value = Mock()
@@ -143,8 +189,8 @@ class test_TaskSet(AppCase):
             def apply_async(self, *args, **kwargs):
             def apply_async(self, *args, **kwargs):
                 applied[0] += 1
                 applied[0] += 1
 
 
-        ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)], app=self.app)
+        ts = self.TaskSet([mocksubtask(self.MockTask, (i, i))
+                           for i in (2, 4, 8)], app=self.app)
         ts.apply_async()
         ts.apply_async()
         self.assertEqual(applied[0], 3)
         self.assertEqual(applied[0], 3)
 
 
@@ -157,7 +203,7 @@ class test_TaskSet(AppCase):
 
 
         # setting current_task
         # setting current_task
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def xyz():
         def xyz():
             pass
             pass
 
 
@@ -179,22 +225,22 @@ class test_TaskSet(AppCase):
             def apply(self, *args, **kwargs):
             def apply(self, *args, **kwargs):
                 applied[0] += 1
                 applied[0] += 1
 
 
-        ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)], app=self.app)
+        ts = self.TaskSet([mocksubtask(self.MockTask, (i, i))
+                           for i in (2, 4, 8)], app=self.app)
         ts.apply()
         ts.apply()
         self.assertEqual(applied[0], 3)
         self.assertEqual(applied[0], 3)
 
 
     def test_set_app(self):
     def test_set_app(self):
-        ts = TaskSet([], app=self.app)
+        ts = self.TaskSet([], app=self.app)
         ts.app = 42
         ts.app = 42
         self.assertEqual(ts.app, 42)
         self.assertEqual(ts.app, 42)
 
 
     def test_set_tasks(self):
     def test_set_tasks(self):
-        ts = TaskSet([], app=self.app)
+        ts = self.TaskSet([], app=self.app)
         ts.tasks = [1, 2, 3]
         ts.tasks = [1, 2, 3]
         self.assertEqual(ts, [1, 2, 3])
         self.assertEqual(ts, [1, 2, 3])
 
 
     def test_set_Publisher(self):
     def test_set_Publisher(self):
-        ts = TaskSet([], app=self.app)
+        ts = self.TaskSet([], app=self.app)
         ts.Publisher = 42
         ts.Publisher = 42
         self.assertEqual(ts.Publisher, 42)
         self.assertEqual(ts.Publisher, 42)

+ 2 - 2
celery/tests/concurrency/test_concurrency.py

@@ -6,10 +6,10 @@ from itertools import count
 from mock import Mock
 from mock import Mock
 
 
 from celery.concurrency.base import apply_target, BasePool
 from celery.concurrency.base import apply_target, BasePool
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class test_BasePool(Case):
+class test_BasePool(AppCase):
 
 
     def test_apply_target(self):
     def test_apply_target(self):
 
 

+ 4 - 4
celery/tests/concurrency/test_eventlet.py

@@ -14,13 +14,13 @@ from celery.concurrency.eventlet import (
     TaskPool,
     TaskPool,
 )
 )
 
 
-from celery.tests.case import Case, mock_module, patch_many, skip_if_pypy
+from celery.tests.case import AppCase, mock_module, patch_many, skip_if_pypy
 
 
 
 
-class EventletCase(Case):
+class EventletCase(AppCase):
 
 
     @skip_if_pypy
     @skip_if_pypy
-    def setUp(self):
+    def setup(self):
         if is_pypy:
         if is_pypy:
             raise SkipTest('mock_modules not working on PyPy1.9')
             raise SkipTest('mock_modules not working on PyPy1.9')
         try:
         try:
@@ -30,7 +30,7 @@ class EventletCase(Case):
                 'eventlet not installed, skipping related tests.')
                 'eventlet not installed, skipping related tests.')
 
 
     @skip_if_pypy
     @skip_if_pypy
-    def tearDown(self):
+    def teardown(self):
         for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
         for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
             try:
             try:
                 del(sys.modules[mod])
                 del(sys.modules[mod])

+ 7 - 7
celery/tests/concurrency/test_gevent.py

@@ -14,7 +14,7 @@ from celery.concurrency.gevent import (
 )
 )
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    Case, mock_module, patch_many, skip_if_pypy,
+    AppCase, mock_module, patch_many, skip_if_pypy,
 )
 )
 
 
 gevent_modules = (
 gevent_modules = (
@@ -26,10 +26,10 @@ gevent_modules = (
 )
 )
 
 
 
 
-class GeventCase(Case):
+class GeventCase(AppCase):
 
 
     @skip_if_pypy
     @skip_if_pypy
-    def setUp(self):
+    def setup(self):
         try:
         try:
             self.gevent = __import__('gevent')
             self.gevent = __import__('gevent')
         except ImportError:
         except ImportError:
@@ -58,7 +58,7 @@ class test_gevent_patch(GeventCase):
                 monkey.patch_all = prev_monkey_patch
                 monkey.patch_all = prev_monkey_patch
 
 
 
 
-class test_Schedule(Case):
+class test_Schedule(AppCase):
 
 
     def test_sched(self):
     def test_sched(self):
         with mock_module(*gevent_modules):
         with mock_module(*gevent_modules):
@@ -88,7 +88,7 @@ class test_Schedule(Case):
                 g.cancel()
                 g.cancel()
 
 
 
 
-class test_TasKPool(Case):
+class test_TaskPool(AppCase):
 
 
     def test_pool(self):
     def test_pool(self):
         with mock_module(*gevent_modules):
         with mock_module(*gevent_modules):
@@ -115,7 +115,7 @@ class test_TasKPool(Case):
                 self.assertEqual(x.num_processes, 3)
                 self.assertEqual(x.num_processes, 3)
 
 
 
 
-class test_Timer(Case):
+class test_Timer(AppCase):
 
 
     def test_timer(self):
     def test_timer(self):
         with mock_module(*gevent_modules):
         with mock_module(*gevent_modules):
@@ -127,7 +127,7 @@ class test_Timer(Case):
             x.schedule.clear.assert_called_with()
             x.schedule.clear.assert_called_with()
 
 
 
 
-class test_apply_timeout(Case):
+class test_apply_timeout(AppCase):
 
 
     def test_apply_timeout(self):
     def test_apply_timeout(self):
 
 

+ 3 - 3
celery/tests/concurrency/test_pool.py

@@ -7,7 +7,7 @@ from nose import SkipTest
 
 
 from billiard.einfo import ExceptionInfo
 from billiard.einfo import ExceptionInfo
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
 def do_something(i):
 def do_something(i):
@@ -25,9 +25,9 @@ def raise_something(i):
         return ExceptionInfo()
         return ExceptionInfo()
 
 
 
 
-class test_TaskPool(Case):
+class test_TaskPool(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         try:
         try:
             __import__('multiprocessing')
             __import__('multiprocessing')
         except ImportError:
         except ImportError:

+ 2 - 2
celery/tests/concurrency/test_solo.py

@@ -4,10 +4,10 @@ import operator
 
 
 from celery.concurrency import solo
 from celery.concurrency import solo
 from celery.utils.functional import noop
 from celery.utils.functional import noop
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class test_solo_TaskPool(Case):
+class test_solo_TaskPool(AppCase):
 
 
     def test_on_start(self):
     def test_on_start(self):
         x = solo.TaskPool()
         x = solo.TaskPool()

+ 2 - 2
celery/tests/concurrency/test_threads.py

@@ -4,7 +4,7 @@ from mock import Mock
 
 
 from celery.concurrency.threads import NullDict, TaskPool, apply_target
 from celery.concurrency.threads import NullDict, TaskPool, apply_target
 
 
-from celery.tests.case import Case, mask_modules, mock_module
+from celery.tests.case import AppCase, Case, mask_modules, mock_module
 
 
 
 
 class test_NullDict(Case):
 class test_NullDict(Case):
@@ -16,7 +16,7 @@ class test_NullDict(Case):
             x['foo']
             x['foo']
 
 
 
 
-class test_TaskPool(Case):
+class test_TaskPool(AppCase):
 
 
     def test_without_threadpool(self):
     def test_without_threadpool(self):
 
 

+ 0 - 44
celery/tests/config.py

@@ -1,44 +0,0 @@
-from __future__ import absolute_import
-
-import os
-
-from kombu import Queue
-
-BROKER_URL = 'memory://'
-
-#: warn if config module not found
-os.environ['C_WNOCONF'] = 'yes'
-
-#: Don't want log output when running suite.
-CELERYD_HIJACK_ROOT_LOGGER = False
-
-CELERY_RESULT_BACKEND = 'cache'
-CELERY_CACHE_BACKEND = 'memory'
-CELERY_RESULT_DBURI = 'sqlite:///test.db'
-CELERY_SEND_TASK_ERROR_EMAILS = False
-
-CELERY_DEFAULT_QUEUE = 'testcelery'
-CELERY_DEFAULT_EXCHANGE = 'testcelery'
-CELERY_DEFAULT_ROUTING_KEY = 'testcelery'
-CELERY_QUEUES = (
-    Queue('testcelery', routing_key='testcelery'),
-)
-
-CELERY_ENABLE_UTC = True
-CELERY_TIMEZONE = 'UTC'
-
-CELERYD_LOG_COLOR = False
-
-# Mongo results tests (only executed if installed and running)
-CELERY_MONGODB_BACKEND_SETTINGS = {
-    'host': os.environ.get('MONGO_HOST') or 'localhost',
-    'port': os.environ.get('MONGO_PORT') or 27017,
-    'database': os.environ.get('MONGO_DB') or 'celery_unittests',
-    'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION')
-                            or 'taskmeta_collection'),
-}
-if os.environ.get('MONGO_USER'):
-    CELERY_MONGODB_BACKEND_SETTINGS['user'] = os.environ.get('MONGO_USER')
-if os.environ.get('MONGO_PASSWORD'):
-    CELERY_MONGODB_BACKEND_SETTINGS['password'] = \
-        os.environ.get('MONGO_PASSWORD')

+ 24 - 26
celery/tests/contrib/test_abortable.py

@@ -1,51 +1,49 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
-from celery.result import AsyncResult
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class MyAbortableTask(AbortableTask):
+class test_AbortableTask(AppCase):
 
 
-    def run(self, **kwargs):
-        return True
+    def setup(self):
 
 
-
-class test_AbortableTask(Case):
+        @self.app.task(base=AbortableTask, shared=False)
+        def abortable():
+            return True
+        self.abortable = abortable
 
 
     def test_async_result_is_abortable(self):
     def test_async_result_is_abortable(self):
-        t = MyAbortableTask()
-        result = t.apply_async()
+        result = self.abortable.apply_async()
         tid = result.id
         tid = result.id
-        self.assertIsInstance(t.AsyncResult(tid), AbortableAsyncResult)
+        self.assertIsInstance(
+            self.abortable.AsyncResult(tid), AbortableAsyncResult,
+        )
 
 
     def test_is_not_aborted(self):
     def test_is_not_aborted(self):
-        t = MyAbortableTask()
-        t.push_request()
+        self.abortable.push_request()
         try:
         try:
-            result = t.apply_async()
+            result = self.abortable.apply_async()
             tid = result.id
             tid = result.id
-            self.assertFalse(t.is_aborted(task_id=tid))
+            self.assertFalse(self.abortable.is_aborted(task_id=tid))
         finally:
         finally:
-            t.pop_request()
+            self.abortable.pop_request()
 
 
     def test_is_aborted_not_abort_result(self):
     def test_is_aborted_not_abort_result(self):
-        t = MyAbortableTask()
-        t.AsyncResult = AsyncResult
-        t.push_request()
+        self.abortable.AsyncResult = self.app.AsyncResult
+        self.abortable.push_request()
         try:
         try:
-            t.request.id = 'foo'
-            self.assertFalse(t.is_aborted())
+            self.abortable.request.id = 'foo'
+            self.assertFalse(self.abortable.is_aborted())
         finally:
         finally:
-            t.pop_request()
+            self.abortable.pop_request()
 
 
     def test_abort_yields_aborted(self):
     def test_abort_yields_aborted(self):
-        t = MyAbortableTask()
-        t.push_request()
+        self.abortable.push_request()
         try:
         try:
-            result = t.apply_async()
+            result = self.abortable.apply_async()
             result.abort()
             result.abort()
             tid = result.id
             tid = result.id
-            self.assertTrue(t.is_aborted(task_id=tid))
+            self.assertTrue(self.abortable.is_aborted(task_id=tid))
         finally:
         finally:
-            t.pop_request()
+            self.abortable.pop_request()

+ 1 - 1
celery/tests/contrib/test_methods.py

@@ -14,7 +14,7 @@ class test_task_method(AppCase):
             def __init__(self):
             def __init__(self):
                 self.state = 0
                 self.state = 0
 
 
-            @self.app.task(filter=task_method)
+            @self.app.task(shared=False, filter=task_method)
             def add(self, x):
             def add(self, x):
                 self.state += x
                 self.state += x
 
 

+ 5 - 5
celery/tests/contrib/test_migrate.py

@@ -26,7 +26,7 @@ from celery.contrib.migrate import (
     move,
     move,
 )
 )
 from celery.utils.encoding import bytes_t, ensure_bytes
 from celery.utils.encoding import bytes_t, ensure_bytes
-from celery.tests.case import AppCase, Case, Mock, override_stdouts
+from celery.tests.case import AppCase, Mock, override_stdouts
 
 
 # hack to ignore error at shutdown
 # hack to ignore error at shutdown
 QoS.restore_at_shutdown = False
 QoS.restore_at_shutdown = False
@@ -52,7 +52,7 @@ def Message(body, exchange='exchange', routing_key='rkey',
     )
     )
 
 
 
 
-class test_State(Case):
+class test_State(AppCase):
 
 
     def test_strtotal(self):
     def test_strtotal(self):
         x = State()
         x = State()
@@ -178,7 +178,7 @@ class test_start_filter(AppCase):
             self.assertTrue(stop_filtering_raised)
             self.assertTrue(stop_filtering_raised)
 
 
 
 
-class test_filter_callback(Case):
+class test_filter_callback(AppCase):
 
 
     def test_filter(self):
     def test_filter(self):
         callback = Mock()
         callback = Mock()
@@ -193,7 +193,7 @@ class test_filter_callback(Case):
         callback.assert_called_with(t1, message)
         callback.assert_called_with(t1, message)
 
 
 
 
-class test_utils(Case):
+class test_utils(AppCase):
 
 
     def test_task_id_in(self):
     def test_task_id_in(self):
         self.assertTrue(task_id_in(['A'], {'id': 'A'}, Mock()))
         self.assertTrue(task_id_in(['A'], {'id': 'A'}, Mock()))
@@ -243,7 +243,7 @@ class test_utils(Case):
             )
             )
 
 
 
 
-class test_migrate_task(Case):
+class test_migrate_task(AppCase):
 
 
     def test_removes_compression_header(self):
     def test_removes_compression_header(self):
         x = Message('foo', compression='zlib')
         x = Message('foo', compression='zlib')

+ 10 - 13
celery/tests/events/test_events.py

@@ -4,7 +4,6 @@ import socket
 
 
 from mock import Mock
 from mock import Mock
 
 
-from celery import Celery
 from celery.events import Event
 from celery.events import Event
 from celery.tests.case import AppCase
 from celery.tests.case import AppCase
 
 
@@ -41,21 +40,19 @@ class test_Event(AppCase):
 class test_EventDispatcher(AppCase):
 class test_EventDispatcher(AppCase):
 
 
     def test_redis_uses_fanout_exchange(self):
     def test_redis_uses_fanout_exchange(self):
-        with Celery(set_as_current=False) as app:
-            app.connection = Mock()
-            conn = app.connection.return_value = Mock()
-            conn.transport.driver_type = 'redis'
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock()
+        conn.transport.driver_type = 'redis'
 
 
-            dispatcher = app.events.Dispatcher(conn, enabled=False)
-            self.assertEqual(dispatcher.exchange.type, 'fanout')
+        dispatcher = self.app.events.Dispatcher(conn, enabled=False)
+        self.assertEqual(dispatcher.exchange.type, 'fanout')
 
 
     def test_others_use_topic_exchange(self):
     def test_others_use_topic_exchange(self):
-        with Celery(set_as_current=False) as app:
-            app.connection = Mock()
-            conn = app.connection.return_value = Mock()
-            conn.transport.driver_type = 'amqp'
-            dispatcher = app.events.Dispatcher(conn, enabled=False)
-            self.assertEqual(dispatcher.exchange.type, 'topic')
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock()
+        conn.transport.driver_type = 'amqp'
+        dispatcher = self.app.events.Dispatcher(conn, enabled=False)
+        self.assertEqual(dispatcher.exchange.type, 'topic')
 
 
     def test_takes_channel_connection(self):
     def test_takes_channel_connection(self):
         x = self.app.events.Dispatcher(channel=Mock())
         x = self.app.events.Dispatcher(channel=Mock())

+ 4 - 4
celery/tests/events/test_state.py

@@ -18,7 +18,7 @@ from celery.events.state import (
 )
 )
 from celery.five import range
 from celery.five import range
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
 class replay(object):
 class replay(object):
@@ -152,7 +152,7 @@ class ev_snapshot(replay):
                                uuid=uuid(), hostname=worker))
                                uuid=uuid(), hostname=worker))
 
 
 
 
-class test_Worker(Case):
+class test_Worker(AppCase):
 
 
     def test_equality(self):
     def test_equality(self):
         self.assertEqual(Worker(hostname='foo').hostname, 'foo')
         self.assertEqual(Worker(hostname='foo').hostname, 'foo')
@@ -192,7 +192,7 @@ class test_Worker(Case):
         self.assertEqual(len(worker.heartbeats), 1)
         self.assertEqual(len(worker.heartbeats), 1)
 
 
 
 
-class test_Task(Case):
+class test_Task(AppCase):
 
 
     def test_equality(self):
     def test_equality(self):
         self.assertEqual(Task(uuid='foo').uuid, 'foo')
         self.assertEqual(Task(uuid='foo').uuid, 'foo')
@@ -265,7 +265,7 @@ class test_Task(Case):
         self.assertTrue(repr(Task(uuid='xxx', name='tasks.add')))
         self.assertTrue(repr(Task(uuid='xxx', name='tasks.add')))
 
 
 
 
-class test_State(Case):
+class test_State(AppCase):
 
 
     def test_repr(self):
     def test_repr(self):
         self.assertTrue(repr(State()))
         self.assertTrue(repr(State()))

+ 5 - 7
celery/tests/fixups/test_django.py

@@ -5,7 +5,6 @@ import os
 from contextlib import contextmanager
 from contextlib import contextmanager
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import Celery
 from celery.fixups.django import (
 from celery.fixups.django import (
     _maybe_close_fd,
     _maybe_close_fd,
     fixup,
     fixup,
@@ -62,10 +61,9 @@ class test_DjangoFixup(AppCase):
             self.assertIsNone(DjangoFixup(self.app)._close_old_connections)
             self.assertIsNone(DjangoFixup(self.app)._close_old_connections)
 
 
     def test_install(self):
     def test_install(self):
-        app = Celery(set_as_current=False)
-        app.conf = {'CELERY_DB_REUSE_MAX': None}
-        app.loader = Mock()
-        with self.fixup_context(app) as (f, _, _):
+        self.app.conf = {'CELERY_DB_REUSE_MAX': None}
+        self.app.loader = Mock()
+        with self.fixup_context(self.app) as (f, _, _):
             with patch_many('os.getcwd', 'sys.path',
             with patch_many('os.getcwd', 'sys.path',
                             'celery.fixups.django.signals') as (cw, p, sigs):
                             'celery.fixups.django.signals') as (cw, p, sigs):
                 cw.return_value = '/opt/vandelay'
                 cw.return_value = '/opt/vandelay'
@@ -80,8 +78,8 @@ class test_DjangoFixup(AppCase):
                 sigs.worker_process_init.connect.assert_called_with(
                 sigs.worker_process_init.connect.assert_called_with(
                     f.on_worker_process_init,
                     f.on_worker_process_init,
                 )
                 )
-                self.assertEqual(app.loader.now, f.now)
-                self.assertEqual(app.loader.mail_admins, f.mail_admins)
+                self.assertEqual(self.app.loader.now, f.now)
+                self.assertEqual(self.app.loader.mail_admins, f.mail_admins)
                 p.append.assert_called_with('/opt/vandelay')
                 p.append.assert_called_with('/opt/vandelay')
 
 
     def test_now(self):
     def test_now(self):

+ 13 - 9
celery/tests/functional/case.py

@@ -11,8 +11,9 @@ import traceback
 from itertools import count
 from itertools import count
 from time import time
 from time import time
 
 
+from celery import current_app
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
-from celery.task.control import ping, flatten_reply, inspect
+from celery.app.control import flatten_reply
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
 
 
 from celery.tests.case import Case
 from celery.tests.case import Case
@@ -39,9 +40,10 @@ class Worker(object):
     worker_ids = count(1)
     worker_ids = count(1)
     _shutdown_called = False
     _shutdown_called = False
 
 
-    def __init__(self, hostname, loglevel='error'):
+    def __init__(self, hostname, loglevel='error', app=None):
         self.hostname = hostname
         self.hostname = hostname
         self.loglevel = loglevel
         self.loglevel = loglevel
+        self.app = app or current_app._get_current_object()
 
 
     def start(self):
     def start(self):
         if not self.started:
         if not self.started:
@@ -51,16 +53,17 @@ class Worker(object):
     def _fork_and_exec(self):
     def _fork_and_exec(self):
         pid = os.fork()
         pid = os.fork()
         if pid == 0:
         if pid == 0:
-            from celery import current_app
-            current_app.worker_main(['worker', '--loglevel=INFO',
-                                               '-n', self.hostname,
-                                               '-P', 'solo'])
+            self.app.worker_main(['worker', '--loglevel=INFO',
+                                  '-n', self.hostname,
+                                  '-P', 'solo'])
             os._exit(0)
             os._exit(0)
         self.pid = pid
         self.pid = pid
 
 
+    def ping(self, *args, **kwargs):
+        return self.app.control.ping(*args, **kwargs)
+
     def is_alive(self, timeout=1):
     def is_alive(self, timeout=1):
-        r = ping(destination=[self.hostname],
-                 timeout=timeout)
+        r = self.ping(destination=[self.hostname], timeout=timeout)
         return self.hostname in flatten_reply(r)
         return self.hostname in flatten_reply(r)
 
 
     def wait_until_started(self, timeout=10, interval=0.5):
     def wait_until_started(self, timeout=10, interval=0.5):
@@ -124,7 +127,8 @@ class WorkerCase(Case):
         self.assertTrue(self.worker.is_alive)
         self.assertTrue(self.worker.is_alive)
 
 
     def inspect(self, timeout=1):
     def inspect(self, timeout=1):
-        return inspect([self.worker.hostname], timeout=timeout)
+        return self.app.control.inspect([self.worker.hostname],
+                                        timeout=timeout)
 
 
     def my_response(self, response):
     def my_response(self, response):
         return flatten_reply(response)[self.worker.hostname]
         return flatten_reply(response)[self.worker.hostname]

+ 25 - 37
celery/tests/security/test_security.py

@@ -32,7 +32,7 @@ from celery.tests.case import mock_open
 
 
 class test_security(SecurityCase):
 class test_security(SecurityCase):
 
 
-    def tearDown(self):
+    def teardown(self):
         registry._disabled_content_types.clear()
         registry._disabled_content_types.clear()
 
 
     def test_disable_insecure_serializers(self):
     def test_disable_insecure_serializers(self):
@@ -59,14 +59,10 @@ class test_security(SecurityCase):
         disabled = registry._disabled_content_types
         disabled = registry._disabled_content_types
         self.assertEqual(0, len(disabled))
         self.assertEqual(0, len(disabled))
 
 
-        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
-            self.app.conf.CELERY_TASK_SERIALIZER, 'json')
-        try:
-            self.app.setup_security()
-            self.assertIn('application/x-python-serialize', disabled)
-            disabled.clear()
-        finally:
-            self.app.conf.CELERY_TASK_SERIALIZER = prev
+        self.app.conf.CELERY_TASK_SERIALIZER = 'json'
+        self.app.setup_security()
+        self.assertIn('application/x-python-serialize', disabled)
+        disabled.clear()
 
 
     @patch('celery.security.register_auth')
     @patch('celery.security.register_auth')
     @patch('celery.security._disable_insecure_serializers')
     @patch('celery.security._disable_insecure_serializers')
@@ -81,39 +77,31 @@ class test_security(SecurityCase):
             finally:
             finally:
                 calls[0] += 1
                 calls[0] += 1
 
 
-        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
-            self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
-        try:
-            with mock_open(side_effect=effect):
-                with patch('celery.security.registry') as registry:
-                    store = Mock()
-                    self.app.setup_security(['json'], key, cert, store)
-                    dis.assert_called_with(['json'])
-                    reg.assert_called_with('A', 'B', store, 'sha1', 'json')
-                    registry._set_default_serializer.assert_called_with('auth')
-        finally:
-            self.app.conf.CELERY_TASK_SERIALIZER = prev
+        self.app.conf.CELERY_TASK_SERIALIZER = 'auth'
+        with mock_open(side_effect=effect):
+            with patch('celery.security.registry') as registry:
+                store = Mock()
+                self.app.setup_security(['json'], key, cert, store)
+                dis.assert_called_with(['json'])
+                reg.assert_called_with('A', 'B', store, 'sha1', 'json')
+                registry._set_default_serializer.assert_called_with('auth')
 
 
     def test_security_conf(self):
     def test_security_conf(self):
-        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
-            self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                self.app.setup_security()
+        self.app.conf.CELERY_TASK_SERIALIZER = 'auth'
+        with self.assertRaises(ImproperlyConfigured):
+            self.app.setup_security()
 
 
-            _import = builtins.__import__
+        _import = builtins.__import__
 
 
-            def import_hook(name, *args, **kwargs):
-                if name == 'OpenSSL':
-                    raise ImportError
-                return _import(name, *args, **kwargs)
+        def import_hook(name, *args, **kwargs):
+            if name == 'OpenSSL':
+                raise ImportError
+            return _import(name, *args, **kwargs)
 
 
-            builtins.__import__ = import_hook
-            with self.assertRaises(ImproperlyConfigured):
-                self.app.setup_security()
-            builtins.__import__ = _import
-        finally:
-            self.app.conf.CELERY_TASK_SERIALIZER = prev
+        builtins.__import__ = import_hook
+        with self.assertRaises(ImproperlyConfigured):
+            self.app.setup_security()
+        builtins.__import__ = _import
 
 
     def test_reraise_errors(self):
     def test_reraise_errors(self):
         with self.assertRaises(SecurityError):
         with self.assertRaises(SecurityError):

+ 13 - 20
celery/tests/tasks/test_canvas.py

@@ -140,20 +140,17 @@ class test_Signature(CanvasCase):
     def test_election(self):
     def test_election(self):
         x = self.add.s(2, 2)
         x = self.add.s(2, 2)
         x.freeze('foo')
         x.freeze('foo')
-        prev, x.type.app.control = x.type.app.control, Mock()
-        try:
-            r = x.election()
-            self.assertTrue(x.type.app.control.election.called)
-            self.assertEqual(r.id, 'foo')
-        finally:
-            x.type.app.control = prev
-
-    def test_AsyncResult_when_not_registerd(self):
-        s = subtask('xxx.not.registered')
+        x.type.app.control = Mock()
+        r = x.election()
+        self.assertTrue(x.type.app.control.election.called)
+        self.assertEqual(r.id, 'foo')
+
+    def test_AsyncResult_when_not_registered(self):
+        s = subtask('xxx.not.registered', app=self.app)
         self.assertTrue(s.AsyncResult)
         self.assertTrue(s.AsyncResult)
 
 
     def test_apply_async_when_not_registered(self):
     def test_apply_async_when_not_registered(self):
-        s = subtask('xxx.not.registered')
+        s = subtask('xxx.not.registered', app=self.app)
         self.assertTrue(s._apply_async)
         self.assertTrue(s._apply_async)
 
 
 
 
@@ -178,7 +175,9 @@ class test_chunks(CanvasCase):
 
 
     def test_chunks(self):
     def test_chunks(self):
         x = self.add.chunks(range(100), 10)
         x = self.add.chunks(range(100), 10)
-        self.assertEqual(chunks.from_dict(dict(x)), x)
+        self.assertEqual(
+            dict(chunks.from_dict(dict(x), app=self.app)), dict(x),
+        )
 
 
         self.assertTrue(x.group())
         self.assertTrue(x.group())
         self.assertEqual(len(x.group().tasks), 10)
         self.assertEqual(len(x.group().tasks), 10)
@@ -193,10 +192,7 @@ class test_chunks(CanvasCase):
         gr.assert_called_with()
         gr.assert_called_with()
 
 
         self.app.conf.CELERY_ALWAYS_EAGER = True
         self.app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            chunks.apply_chunks(**x['kwargs'])
-        finally:
-            self.app.conf.CELERY_ALWAYS_EAGER = False
+        chunks.apply_chunks(app=self.app, **x['kwargs'])
 
 
 
 
 class test_chain(CanvasCase):
 class test_chain(CanvasCase):
@@ -214,10 +210,7 @@ class test_chain(CanvasCase):
 
 
     def test_always_eager(self):
     def test_always_eager(self):
         self.app.conf.CELERY_ALWAYS_EAGER = True
         self.app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            self.assertEqual(~(self.add.s(4, 4) | self.add.s(8)), 16)
-        finally:
-            self.app.conf.CELERY_ALWAYS_EAGER = False
+        self.assertEqual(~(self.add.s(4, 4) | self.add.s(8)), 16)
 
 
     def test_apply(self):
     def test_apply(self):
         x = chain(self.add.s(4, 4), self.add.s(8), self.add.s(10))
         x = chain(self.add.s(4, 4), self.add.s(8), self.add.s(10))

+ 13 - 19
celery/tests/tasks/test_chord.py

@@ -150,7 +150,7 @@ class test_unlock_chord_task(ChordCase):
                     assert self.app.tasks['celery.chord_unlock'] is unlock
                     assert self.app.tasks['celery.chord_unlock'] is unlock
                     unlock(
                     unlock(
                         'group_id', callback_s,
                         'group_id', callback_s,
-                        result=[AsyncResult(r) for r in ['1', 2, 3]],
+                        result=[self.app.AsyncResult(r) for r in ['1', 2, 3]],
                         GroupResult=ResultCls, **kwargs
                         GroupResult=ResultCls, **kwargs
                     )
                     )
                 finally:
                 finally:
@@ -178,22 +178,19 @@ class test_chord(ChordCase):
     def test_eager(self):
     def test_eager(self):
         from celery import chord
         from celery import chord
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def addX(x, y):
         def addX(x, y):
             return x + y
             return x + y
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def sumX(n):
         def sumX(n):
             return sum(n)
             return sum(n)
 
 
         self.app.conf.CELERY_ALWAYS_EAGER = True
         self.app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            x = chord(addX.s(i, i) for i in range(10))
-            body = sumX.s()
-            result = x(body)
-            self.assertEqual(result.get(), sum(i + i for i in range(10)))
-        finally:
-            self.app.conf.CELERY_ALWAYS_EAGER = False
+        x = chord(addX.s(i, i) for i in range(10))
+        body = sumX.s()
+        result = x(body)
+        self.assertEqual(result.get(), sum(i + i for i in range(10)))
 
 
     def test_apply(self):
     def test_apply(self):
         self.app.conf.CELERY_ALWAYS_EAGER = False
         self.app.conf.CELERY_ALWAYS_EAGER = False
@@ -219,15 +216,12 @@ class test_chord(ChordCase):
 class test_Chord_task(ChordCase):
 class test_Chord_task(ChordCase):
 
 
     def test_run(self):
     def test_run(self):
-        prev, self.app.backend = self.app.backend, Mock()
+        self.app.backend = Mock()
         self.app.backend.cleanup = Mock()
         self.app.backend.cleanup = Mock()
         self.app.backend.cleanup.__name__ = 'cleanup'
         self.app.backend.cleanup.__name__ = 'cleanup'
-        try:
-            Chord = self.app.tasks['celery.chord']
+        Chord = self.app.tasks['celery.chord']
 
 
-            body = dict()
-            Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
-            Chord([self.add.subtask((j, j)) for j in range(5)], body)
-            self.assertEqual(self.app.backend.on_chord_apply.call_count, 2)
-        finally:
-            self.app.backend = prev
+        body = dict()
+        Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
+        Chord([self.add.subtask((j, j)) for j in range(5)], body)
+        self.assertEqual(self.app.backend.on_chord_apply.call_count, 2)

+ 3 - 3
celery/tests/tasks/test_context.py

@@ -3,8 +3,8 @@ from __future__ import absolute_import
 
 
 from collections import Callable
 from collections import Callable
 
 
-from celery.task.base import Context
-from celery.tests.case import Case
+from celery.app.task import Context
+from celery.tests.case import AppCase
 
 
 
 
 # Retreive the values of all context attributes as a
 # Retreive the values of all context attributes as a
@@ -22,7 +22,7 @@ def get_context_as_dict(ctx, getter=getattr):
 default_context = get_context_as_dict(Context())
 default_context = get_context_as_dict(Context())
 
 
 
 
-class test_Context(Case):
+class test_Context(AppCase):
 
 
     def test_default_context(self):
     def test_default_context(self):
         # A bit of a tautological test, since it uses the same
         # A bit of a tautological test, since it uses the same

+ 16 - 16
celery/tests/tasks/test_result.py

@@ -10,15 +10,12 @@ from celery.result import (
     AsyncResult,
     AsyncResult,
     EagerResult,
     EagerResult,
     TaskSetResult,
     TaskSetResult,
-    ResultSet,
-    #GroupResult,
     from_serializable,
     from_serializable,
 )
 )
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 
 
-from celery.tests.case import AppCase
-from celery.tests.case import skip_if_quick
+from celery.tests.case import AppCase, depends_on_current_app, skip_if_quick
 
 
 
 
 def mock_task(name, state, result):
 def mock_task(name, state, result):
@@ -56,7 +53,7 @@ class test_AsyncResult(AppCase):
         for task in (self.task1, self.task2, self.task3, self.task4):
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(self.app, task)
             save_result(self.app, task)
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def mytask():
         def mytask():
             pass
             pass
         self.mytask = mytask
         self.mytask = mytask
@@ -150,6 +147,7 @@ class test_AsyncResult(AppCase):
     def test_eq_not_implemented(self):
     def test_eq_not_implemented(self):
         self.assertFalse(self.app.AsyncResult('1') == object())
         self.assertFalse(self.app.AsyncResult('1') == object())
 
 
+    @depends_on_current_app
     def test_reduce(self):
     def test_reduce(self):
         a1 = self.app.AsyncResult('uuid', task_name=self.mytask.name)
         a1 = self.app.AsyncResult('uuid', task_name=self.mytask.name)
         restored = pickle.loads(pickle.dumps(a1))
         restored = pickle.loads(pickle.dumps(a1))
@@ -261,15 +259,15 @@ class test_AsyncResult(AppCase):
 class test_ResultSet(AppCase):
 class test_ResultSet(AppCase):
 
 
     def test_resultset_repr(self):
     def test_resultset_repr(self):
-        self.assertTrue(repr(ResultSet(
+        self.assertTrue(repr(self.app.ResultSet(
             [self.app.AsyncResult(t) for t in ['1', '2', '3']])))
             [self.app.AsyncResult(t) for t in ['1', '2', '3']])))
 
 
     def test_eq_other(self):
     def test_eq_other(self):
-        self.assertFalse(ResultSet([1, 3, 3]) == 1)
-        self.assertTrue(ResultSet([1]) == ResultSet([1]))
+        self.assertFalse(self.app.ResultSet([1, 3, 3]) == 1)
+        self.assertTrue(self.app.ResultSet([1]) == self.app.ResultSet([1]))
 
 
     def test_get(self):
     def test_get(self):
-        x = ResultSet([self.app.AsyncResult(t) for t in [1, 2, 3]])
+        x = self.app.ResultSet([self.app.AsyncResult(t) for t in [1, 2, 3]])
         b = x.results[0].backend = Mock()
         b = x.results[0].backend = Mock()
         b.supports_native_join = False
         b.supports_native_join = False
         x.join_native = Mock()
         x.join_native = Mock()
@@ -281,7 +279,7 @@ class test_ResultSet(AppCase):
         self.assertTrue(x.join_native.called)
         self.assertTrue(x.join_native.called)
 
 
     def test_add(self):
     def test_add(self):
-        x = ResultSet([1])
+        x = self.app.ResultSet([1])
         x.add(2)
         x.add(2)
         self.assertEqual(len(x), 2)
         self.assertEqual(len(x), 2)
         x.add(2)
         x.add(2)
@@ -311,7 +309,7 @@ class test_ResultSet(AppCase):
         ready.return_value = False
         ready.return_value = False
         ready.side_effect = se
         ready.side_effect = se
 
 
-        x = ResultSet([r1, r2])
+        x = self.app.ResultSet([r1, r2])
         with self.dummy_copy():
         with self.dummy_copy():
             with patch('celery.result.time') as _time:
             with patch('celery.result.time') as _time:
                 with self.assertRaises(KeyError):
                 with self.assertRaises(KeyError):
@@ -330,14 +328,14 @@ class test_ResultSet(AppCase):
         r1 = self.app.AsyncResult(uuid)
         r1 = self.app.AsyncResult(uuid)
         r1.ready = Mock()
         r1.ready = Mock()
         r1.ready.return_value = False
         r1.ready.return_value = False
-        x = ResultSet([r1])
+        x = self.app.ResultSet([r1])
         with self.dummy_copy():
         with self.dummy_copy():
             with patch('celery.result.time'):
             with patch('celery.result.time'):
                 with self.assertRaises(TimeoutError):
                 with self.assertRaises(TimeoutError):
                     list(x.iterate(timeout=1))
                     list(x.iterate(timeout=1))
 
 
     def test_add_discard(self):
     def test_add_discard(self):
-        x = ResultSet([])
+        x = self.app.ResultSet([])
         x.add(self.app.AsyncResult('1'))
         x.add(self.app.AsyncResult('1'))
         self.assertIn(self.app.AsyncResult('1'), x.results)
         self.assertIn(self.app.AsyncResult('1'), x.results)
         x.discard(self.app.AsyncResult('1'))
         x.discard(self.app.AsyncResult('1'))
@@ -348,7 +346,7 @@ class test_ResultSet(AppCase):
         x.update([self.app.AsyncResult('2')])
         x.update([self.app.AsyncResult('2')])
 
 
     def test_clear(self):
     def test_clear(self):
-        x = ResultSet([])
+        x = self.app.ResultSet([])
         r = x.results
         r = x.results
         x.clear()
         x.clear()
         self.assertIs(x.results, r)
         self.assertIs(x.results, r)
@@ -432,6 +430,7 @@ class test_GroupResult(AppCase):
             uuid(), make_mock_group(self.app, self.size),
             uuid(), make_mock_group(self.app, self.size),
         )
         )
 
 
+    @depends_on_current_app
     def test_is_pickleable(self):
     def test_is_pickleable(self):
         ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
         ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
         self.assertEqual(pickle.loads(pickle.dumps(ts)), ts)
         self.assertEqual(pickle.loads(pickle.dumps(ts)), ts)
@@ -444,6 +443,7 @@ class test_GroupResult(AppCase):
     def test_eq_other(self):
     def test_eq_other(self):
         self.assertFalse(self.ts == 1)
         self.assertFalse(self.ts == 1)
 
 
+    @depends_on_current_app
     def test_reduce(self):
     def test_reduce(self):
         self.assertTrue(pickle.loads(pickle.dumps(self.ts)))
         self.assertTrue(pickle.loads(pickle.dumps(self.ts)))
 
 
@@ -660,7 +660,7 @@ class test_EagerResult(AppCase):
 
 
     def setup(self):
     def setup(self):
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def raising(x, y):
         def raising(x, y):
             raise KeyError(x, y)
             raise KeyError(x, y)
         self.raising = raising
         self.raising = raising
@@ -703,7 +703,7 @@ class test_serializable(AppCase):
 
 
     def test_compat(self):
     def test_compat(self):
         uid = uuid()
         uid = uuid()
-        x = from_serializable([uid, []])
+        x = from_serializable([uid, []], app=self.app)
         self.assertEqual(x.id, uid)
         self.assertEqual(x.id, uid)
 
 
     def test_GroupResult(self):
     def test_GroupResult(self):

+ 49 - 815
celery/tests/tasks/test_tasks.py

@@ -1,29 +1,20 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
-import time
+
 from collections import Callable
 from collections import Callable
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from functools import wraps
 from mock import patch
 from mock import patch
-from nose import SkipTest
-from pickle import loads, dumps
 
 
 from kombu import Queue
 from kombu import Queue
 
 
 from celery import Task
 from celery import Task
 
 
-from celery.task import (
-    periodic_task,
-    PeriodicTask
-)
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
-from celery.execute import send_task
 from celery.five import items, range, string_t
 from celery.five import items, range, string_t
 from celery.result import EagerResult
 from celery.result import EagerResult
-from celery.schedules import crontab, crontab_parser, ParseException
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.utils.timeutils import parse_iso8601, timedelta_seconds
+from celery.utils.timeutils import parse_iso8601
 
 
-from celery.tests.case import AppCase
+from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
 def return_True(*args, **kwargs):
 def return_True(*args, **kwargs):
@@ -49,8 +40,7 @@ class MockApplyTask(Task):
 class TasksCase(AppCase):
 class TasksCase(AppCase):
 
 
     def setup(self):
     def setup(self):
-
-        self.return_True_task = self.app.task(shared=False)(return_True)
+        self.mytask = self.app.task(shared=False)(return_True)
 
 
         @self.app.task(bind=True, count=0, shared=False)
         @self.app.task(bind=True, count=0, shared=False)
         def increment_counter(self, increment_by=1):
         def increment_counter(self, increment_by=1):
@@ -225,8 +215,8 @@ class test_tasks(TasksCase):
     def now(self):
     def now(self):
         return self.app.now()
         return self.app.now()
 
 
+    @depends_on_current_app
     def test_unpickle_task(self):
     def test_unpickle_task(self):
-        self.app.set_current()
         import pickle
         import pickle
 
 
         @self.app.task(shared=True)
         @self.app.task(shared=True)
@@ -234,10 +224,6 @@ class test_tasks(TasksCase):
             pass
             pass
         self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
         self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
 
 
-    def create_task(self, name):
-        return self.app.task(__module__=self.__module__,
-                             shared=False, name=name)(return_True)
-
     def test_AsyncResult(self):
     def test_AsyncResult(self):
         task_id = uuid()
         task_id = uuid()
         result = self.retry_task.AsyncResult(task_id)
         result = self.retry_task.AsyncResult(task_id)
@@ -265,6 +251,7 @@ class test_tasks(TasksCase):
     def test_incomplete_task_cls(self):
     def test_incomplete_task_cls(self):
 
 
         class IncompleteTask(Task):
         class IncompleteTask(Task):
+            app = self.app
             name = 'c.unittest.t.itask'
             name = 'c.unittest.t.itask'
 
 
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
@@ -279,11 +266,11 @@ class test_tasks(TasksCase):
             self.increment_counter.apply_async('str', {})
             self.increment_counter.apply_async('str', {})
 
 
     def test_regular_task(self):
     def test_regular_task(self):
-        T1 = self.create_task('c.unittest.t.t1')
-        self.assertIsInstance(T1, Task)
-        self.assertTrue(T1.run())
-        self.assertTrue(isinstance(T1, Callable), 'Task class is callable()')
-        self.assertTrue(T1(), 'Task class runs run() when called')
+        self.assertIsInstance(self.mytask, Task)
+        self.assertTrue(self.mytask.run())
+        self.assertTrue(isinstance(self.mytask, Callable),
+                        'Task class is callable()')
+        self.assertTrue(self.mytask(), 'Task class runs run() when called')
 
 
         with self.app.connection_or_acquire() as conn:
         with self.app.connection_or_acquire() as conn:
             consumer = self.app.amqp.TaskConsumer(conn)
             consumer = self.app.amqp.TaskConsumer(conn)
@@ -294,54 +281,57 @@ class test_tasks(TasksCase):
             self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])
             self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])
 
 
             # Without arguments.
             # Without arguments.
-            presult = T1.delay()
-            self.assertNextTaskDataEqual(consumer, presult, T1.name)
+            presult = self.mytask.delay()
+            self.assertNextTaskDataEqual(consumer, presult, self.mytask.name)
 
 
             # With arguments.
             # With arguments.
-            presult2 = T1.apply_async(kwargs=dict(name='George Costanza'))
+            presult2 = self.mytask.apply_async(
+                kwargs=dict(name='George Costanza'),
+            )
             self.assertNextTaskDataEqual(
             self.assertNextTaskDataEqual(
-                consumer, presult2, T1.name, name='George Costanza',
+                consumer, presult2, self.mytask.name, name='George Costanza',
             )
             )
 
 
             # send_task
             # send_task
-            sresult = send_task(T1.name, kwargs=dict(name='Elaine M. Benes'))
+            sresult = self.app.send_task(self.mytask.name,
+                                         kwargs=dict(name='Elaine M. Benes'))
             self.assertNextTaskDataEqual(
             self.assertNextTaskDataEqual(
-                consumer, sresult, T1.name, name='Elaine M. Benes',
+                consumer, sresult, self.mytask.name, name='Elaine M. Benes',
             )
             )
 
 
             # With eta.
             # With eta.
-            presult2 = T1.apply_async(
+            presult2 = self.mytask.apply_async(
                 kwargs=dict(name='George Costanza'),
                 kwargs=dict(name='George Costanza'),
                 eta=self.now() + timedelta(days=1),
                 eta=self.now() + timedelta(days=1),
                 expires=self.now() + timedelta(days=2),
                 expires=self.now() + timedelta(days=2),
             )
             )
             self.assertNextTaskDataEqual(
             self.assertNextTaskDataEqual(
-                consumer, presult2, T1.name,
+                consumer, presult2, self.mytask.name,
                 name='George Costanza', test_eta=True, test_expires=True,
                 name='George Costanza', test_eta=True, test_expires=True,
             )
             )
 
 
             # With countdown.
             # With countdown.
-            presult2 = T1.apply_async(kwargs=dict(name='George Costanza'),
-                                      countdown=10, expires=12)
+            presult2 = self.mytask.apply_async(
+                kwargs=dict(name='George Costanza'), countdown=10, expires=12,
+            )
             self.assertNextTaskDataEqual(
             self.assertNextTaskDataEqual(
-                consumer, presult2, T1.name,
+                consumer, presult2, self.mytask.name,
                 name='George Costanza', test_eta=True, test_expires=True,
                 name='George Costanza', test_eta=True, test_expires=True,
             )
             )
 
 
             # Discarding all tasks.
             # Discarding all tasks.
             consumer.purge()
             consumer.purge()
-            T1.apply_async()
+            self.mytask.apply_async()
             self.assertEqual(consumer.purge(), 1)
             self.assertEqual(consumer.purge(), 1)
             self.assertIsNone(consumer.queues[0].get())
             self.assertIsNone(consumer.queues[0].get())
 
 
             self.assertFalse(presult.successful())
             self.assertFalse(presult.successful())
-            T1.backend.mark_as_done(presult.id, result=None)
+            self.mytask.backend.mark_as_done(presult.id, result=None)
             self.assertTrue(presult.successful())
             self.assertTrue(presult.successful())
 
 
     def test_repr_v2_compat(self):
     def test_repr_v2_compat(self):
-        task = type(self.create_task('c.unittest.v2c')._get_current_object())
-        task.__v2_compat__ = True
-        self.assertIn('v2 compatible', repr(task))
+        self.mytask.__v2_compat__ = True
+        self.assertIn('v2 compatible', repr(self.mytask))
 
 
     def test_apply_with_self(self):
     def test_apply_with_self(self):
 
 
@@ -354,46 +344,43 @@ class test_tasks(TasksCase):
         self.assertEqual(tawself(), 42)
         self.assertEqual(tawself(), 42)
 
 
     def test_context_get(self):
     def test_context_get(self):
-        task = self.create_task('c.unittest.t.c.g')
-        task.push_request()
+        self.mytask.push_request()
         try:
         try:
-            request = task.request
+            request = self.mytask.request
             request.foo = 32
             request.foo = 32
             self.assertEqual(request.get('foo'), 32)
             self.assertEqual(request.get('foo'), 32)
             self.assertEqual(request.get('bar', 36), 36)
             self.assertEqual(request.get('bar', 36), 36)
             request.clear()
             request.clear()
         finally:
         finally:
-            task.pop_request()
+            self.mytask.pop_request()
 
 
     def test_task_class_repr(self):
     def test_task_class_repr(self):
-        task = self.create_task('c.unittest.t.repr')
-        self.assertIn('class Task of', repr(task.app.Task))
-        prev, task.app.Task._app = task.app.Task._app, None
-        try:
-            self.assertIn('unbound', repr(task.app.Task, ))
-        finally:
-            task.app.Task._app = prev
+        self.assertIn('class Task of', repr(self.mytask.app.Task))
+        self.mytask.app.Task._app = None
+        self.assertIn('unbound', repr(self.mytask.app.Task, ))
 
 
     def test_bind_no_magic_kwargs(self):
     def test_bind_no_magic_kwargs(self):
-        task = self.create_task('c.unittest.t.magic_kwargs')
-        task.accept_magic_kwargs = None
-        task.bind(task.app)
+        self.mytask.accept_magic_kwargs = None
+        self.mytask.bind(self.mytask.app)
 
 
     def test_annotate(self):
     def test_annotate(self):
         with patch('celery.app.task.resolve_all_annotations') as anno:
         with patch('celery.app.task.resolve_all_annotations') as anno:
             anno.return_value = [{'FOO': 'BAR'}]
             anno.return_value = [{'FOO': 'BAR'}]
-            Task.annotate()
-            self.assertEqual(Task.FOO, 'BAR')
+
+            @self.app.task(shared=False)
+            def task():
+                pass
+            task.annotate()
+            self.assertEqual(task.FOO, 'BAR')
 
 
     def test_after_return(self):
     def test_after_return(self):
-        task = self.create_task('c.unittest.t.after_return')
-        task.push_request()
+        self.mytask.push_request()
         try:
         try:
-            task.request.chord = self.return_True_task.s()
-            task.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
-            task.request.clear()
+            self.mytask.request.chord = self.mytask.s()
+            self.mytask.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
+            self.mytask.request.clear()
         finally:
         finally:
-            task.pop_request()
+            self.mytask.pop_request()
 
 
     def test_send_task_sent_event(self):
     def test_send_task_sent_event(self):
         with self.app.connection() as conn:
         with self.app.connection() as conn:
@@ -471,756 +458,3 @@ class test_apply_task(TasksCase):
         self.assertTrue(f.traceback)
         self.assertTrue(f.traceback)
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             f.get()
             f.get()
-
-
-@periodic_task(run_every=timedelta(hours=1))
-def my_periodic():
-    pass
-
-
-class test_periodic_tasks(AppCase):
-
-    def now(self):
-        return self.app.now()
-
-    def test_must_have_run_every(self):
-        with self.assertRaises(NotImplementedError):
-            type('Foo', (PeriodicTask, ), {'__module__': __name__})
-
-    def test_remaining_estimate(self):
-        s = my_periodic.run_every
-        self.assertIsInstance(
-            s.remaining_estimate(s.maybe_make_aware(self.now())),
-            timedelta)
-
-    def test_is_due_not_due(self):
-        due, remaining = my_periodic.run_every.is_due(self.now())
-        self.assertFalse(due)
-        # This assertion may fail if executed in the
-        # first minute of an hour, thus 59 instead of 60
-        self.assertGreater(remaining, 59)
-
-    def test_is_due(self):
-        p = my_periodic
-        due, remaining = p.run_every.is_due(
-            self.now() - p.run_every.run_every,
-        )
-        self.assertTrue(due)
-        self.assertEqual(remaining,
-                         timedelta_seconds(p.run_every.run_every))
-
-    def test_schedule_repr(self):
-        p = my_periodic
-        self.assertTrue(repr(p.run_every))
-
-
-@periodic_task(run_every=crontab())
-def every_minute():
-    pass
-
-
-@periodic_task(run_every=crontab(minute='*/15'))
-def quarterly():
-    pass
-
-
-@periodic_task(run_every=crontab(minute=30))
-def hourly():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30))
-def daily():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30,
-                                 day_of_week='thursday'))
-def weekly():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30,
-                                 day_of_week='thursday',
-                                 day_of_month='8-14'))
-def monthly():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=22,
-                                 day_of_week='*',
-                                 month_of_year='2',
-                                 day_of_month='26,27,28'))
-def monthly_moy():
-    pass
-
-
-@periodic_task(run_every=crontab(hour=7, minute=30,
-                                 day_of_week='thursday',
-                                 day_of_month='8-14',
-                                 month_of_year=3))
-def yearly():
-    pass
-
-
-def patch_crontab_nowfun(cls, retval):
-
-    def create_patcher(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            prev_nowfun = cls.run_every.nowfun
-            cls.run_every.nowfun = lambda: retval
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                cls.run_every.nowfun = prev_nowfun
-
-        return __inner
-
-    return create_patcher
-
-
-class test_crontab_parser(AppCase):
-
-    def test_crontab_reduce(self):
-        self.assertTrue(loads(dumps(crontab('*'))))
-
-    def test_range_steps_not_enough(self):
-        with self.assertRaises(crontab_parser.ParseException):
-            crontab_parser(24)._range_steps([1])
-
-    def test_parse_star(self):
-        self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
-        self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))
-        self.assertEqual(crontab_parser(7).parse('*'), set(range(7)))
-        self.assertEqual(crontab_parser(31, 1).parse('*'),
-                         set(range(1, 31 + 1)))
-        self.assertEqual(crontab_parser(12, 1).parse('*'),
-                         set(range(1, 12 + 1)))
-
-    def test_parse_range(self):
-        self.assertEqual(crontab_parser(60).parse('1-10'),
-                         set(range(1, 10 + 1)))
-        self.assertEqual(crontab_parser(24).parse('0-20'),
-                         set(range(0, 20 + 1)))
-        self.assertEqual(crontab_parser().parse('2-10'),
-                         set(range(2, 10 + 1)))
-        self.assertEqual(crontab_parser(60, 1).parse('1-10'),
-                         set(range(1, 10 + 1)))
-
-    def test_parse_range_wraps(self):
-        self.assertEqual(crontab_parser(12).parse('11-1'),
-                         set([11, 0, 1]))
-        self.assertEqual(crontab_parser(60, 1).parse('2-1'),
-                         set(range(1, 60 + 1)))
-
-    def test_parse_groups(self):
-        self.assertEqual(crontab_parser().parse('1,2,3,4'),
-                         set([1, 2, 3, 4]))
-        self.assertEqual(crontab_parser().parse('0,15,30,45'),
-                         set([0, 15, 30, 45]))
-        self.assertEqual(crontab_parser(min_=1).parse('1,2,3,4'),
-                         set([1, 2, 3, 4]))
-
-    def test_parse_steps(self):
-        self.assertEqual(crontab_parser(8).parse('*/2'),
-                         set([0, 2, 4, 6]))
-        self.assertEqual(crontab_parser().parse('*/2'),
-                         set(i * 2 for i in range(30)))
-        self.assertEqual(crontab_parser().parse('*/3'),
-                         set(i * 3 for i in range(20)))
-        self.assertEqual(crontab_parser(8, 1).parse('*/2'),
-                         set([1, 3, 5, 7]))
-        self.assertEqual(crontab_parser(min_=1).parse('*/2'),
-                         set(i * 2 + 1 for i in range(30)))
-        self.assertEqual(crontab_parser(min_=1).parse('*/3'),
-                         set(i * 3 + 1 for i in range(20)))
-
-    def test_parse_composite(self):
-        self.assertEqual(crontab_parser(8).parse('*/2'), set([0, 2, 4, 6]))
-        self.assertEqual(crontab_parser().parse('2-9/5'), set([2, 7]))
-        self.assertEqual(crontab_parser().parse('2-10/5'), set([2, 7]))
-        self.assertEqual(
-            crontab_parser(min_=1).parse('55-5/3'),
-            set([55, 58, 1, 4]),
-        )
-        self.assertEqual(crontab_parser().parse('2-11/5,3'), set([2, 3, 7]))
-        self.assertEqual(
-            crontab_parser().parse('2-4/3,*/5,0-21/4'),
-            set([0, 2, 4, 5, 8, 10, 12, 15, 16,
-                 20, 25, 30, 35, 40, 45, 50, 55]),
-        )
-        self.assertEqual(
-            crontab_parser().parse('1-9/2'),
-            set([1, 3, 5, 7, 9]),
-        )
-        self.assertEqual(crontab_parser(8, 1).parse('*/2'), set([1, 3, 5, 7]))
-        self.assertEqual(crontab_parser(min_=1).parse('2-9/5'), set([2, 7]))
-        self.assertEqual(crontab_parser(min_=1).parse('2-10/5'), set([2, 7]))
-        self.assertEqual(
-            crontab_parser(min_=1).parse('2-11/5,3'),
-            set([2, 3, 7]),
-        )
-        self.assertEqual(
-            crontab_parser(min_=1).parse('2-4/3,*/5,1-21/4'),
-            set([1, 2, 5, 6, 9, 11, 13, 16, 17,
-                 21, 26, 31, 36, 41, 46, 51, 56]),
-        )
-        self.assertEqual(
-            crontab_parser(min_=1).parse('1-9/2'),
-            set([1, 3, 5, 7, 9]),
-        )
-
-    def test_parse_errors_on_empty_string(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('')
-
-    def test_parse_errors_on_empty_group(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('1,,2')
-
-    def test_parse_errors_on_empty_steps(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('*/')
-
-    def test_parse_errors_on_negative_number(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('-20')
-
-    def test_parse_errors_on_lt_min(self):
-        crontab_parser(min_=1).parse('1')
-        with self.assertRaises(ValueError):
-            crontab_parser(12, 1).parse('0')
-        with self.assertRaises(ValueError):
-            crontab_parser(24, 1).parse('12-0')
-
-    def test_parse_errors_on_gt_max(self):
-        crontab_parser(1).parse('0')
-        with self.assertRaises(ValueError):
-            crontab_parser(1).parse('1')
-        with self.assertRaises(ValueError):
-            crontab_parser(60).parse('61-0')
-
-    def test_expand_cronspec_eats_iterables(self):
-        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
-                         set([1, 2, 3]))
-        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100, 1),
-                         set([1, 2, 3]))
-
-    def test_expand_cronspec_invalid_type(self):
-        with self.assertRaises(TypeError):
-            crontab._expand_cronspec(object(), 100)
-
-    def test_repr(self):
-        self.assertIn('*', repr(crontab('*')))
-
-    def test_eq(self):
-        self.assertEqual(crontab(day_of_week='1, 2'),
-                         crontab(day_of_week='1-2'))
-        self.assertEqual(crontab(day_of_month='1, 16, 31'),
-                         crontab(day_of_month='*/15'))
-        self.assertEqual(crontab(minute='1', hour='2', day_of_week='5',
-                                 day_of_month='10', month_of_year='5'),
-                         crontab(minute='1', hour='2', day_of_week='5',
-                                 day_of_month='10', month_of_year='5'))
-        self.assertNotEqual(crontab(minute='1'), crontab(minute='2'))
-        self.assertNotEqual(crontab(month_of_year='1'),
-                            crontab(month_of_year='2'))
-        self.assertFalse(object() == crontab(minute='1'))
-        self.assertFalse(crontab(minute='1') == object())
-
-
-class test_crontab_remaining_estimate(AppCase):
-
-    def next_ocurrance(self, crontab, now):
-        crontab.nowfun = lambda: now
-        return now + crontab.remaining_estimate(now)
-
-    def test_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 14, 31))
-
-    def test_not_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 59, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 15, 0))
-
-    def test_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 14, 42))
-
-    def test_not_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 10, 15]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 15, 5))
-
-    def test_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12, 17]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 17, 5))
-
-    def test_not_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 12, 12, 5))
-
-    def test_weekday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='sat'),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 18, 14, 30))
-
-    def test_not_weekday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon-fri'),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 13, 0, 5))
-
-    def test_monthday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_month=18),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 18, 14, 30))
-
-    def test_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_month=29),
-                                   datetime(2010, 1, 22, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 1, 29, 0, 5))
-
-    def test_weekday_monthday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='mon',
-                                           day_of_month=18),
-                                   datetime(2010, 1, 18, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 10, 18, 14, 30))
-
-    def test_monthday_not_weekday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='sat',
-                                           day_of_month=29),
-                                   datetime(2010, 1, 29, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
-
-    def test_weekday_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=18),
-                                   datetime(2010, 1, 11, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 1, 18, 0, 5))
-
-    def test_not_weekday_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=18),
-                                   datetime(2010, 1, 10, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 1, 18, 0, 5))
-
-    def test_leapday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_month=29),
-                                   datetime(2012, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2012, 2, 29, 14, 30))
-
-    def test_not_leapday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_month=29),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 3, 29, 14, 30))
-
-    def test_weekmonthdayyear(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='fri',
-                                           day_of_month=29,
-                                           month_of_year=1),
-                                   datetime(2010, 1, 22, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 1, 29, 14, 30))
-
-    def test_monthdayyear_not_week(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='wed,thu',
-                                           day_of_month=29,
-                                           month_of_year='1,4,7'),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 4, 29, 0, 5))
-
-    def test_weekdaymonthyear_not_monthday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week='fri',
-                                           day_of_month=29,
-                                           month_of_year='1-10'),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 10, 29, 14, 30))
-
-    def test_weekmonthday_not_monthyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='fri',
-                                           day_of_month=29,
-                                           month_of_year='2-10'),
-                                   datetime(2010, 1, 29, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 10, 29, 0, 5))
-
-    def test_weekday_not_monthdayyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=18,
-                                           month_of_year='2-10'),
-                                   datetime(2010, 1, 11, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 10, 18, 0, 5))
-
-    def test_monthday_not_weekdaymonthyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=29,
-                                           month_of_year='2-4'),
-                                   datetime(2010, 1, 29, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 3, 29, 0, 5))
-
-    def test_monthyear_not_weekmonthday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='mon',
-                                           day_of_month=29,
-                                           month_of_year='2-4'),
-                                   datetime(2010, 2, 28, 0, 5, 15))
-        self.assertEqual(next, datetime(2010, 3, 29, 0, 5))
-
-    def test_not_weekmonthdayyear(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week='fri,sat',
-                                           day_of_month=29,
-                                           month_of_year='2-10'),
-                                   datetime(2010, 1, 28, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
-
-
-class test_crontab_is_due(AppCase):
-
-    def getnow(self):
-        return self.app.now()
-
-    def setup(self):
-        self.now = self.getnow()
-        self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
-
-    def test_default_crontab_spec(self):
-        c = crontab()
-        self.assertEqual(c.minute, set(range(60)))
-        self.assertEqual(c.hour, set(range(24)))
-        self.assertEqual(c.day_of_week, set(range(7)))
-        self.assertEqual(c.day_of_month, set(range(1, 32)))
-        self.assertEqual(c.month_of_year, set(range(1, 13)))
-
-    def test_simple_crontab_spec(self):
-        c = crontab(minute=30)
-        self.assertEqual(c.minute, set([30]))
-        self.assertEqual(c.hour, set(range(24)))
-        self.assertEqual(c.day_of_week, set(range(7)))
-        self.assertEqual(c.day_of_month, set(range(1, 32)))
-        self.assertEqual(c.month_of_year, set(range(1, 13)))
-
-    def test_crontab_spec_minute_formats(self):
-        c = crontab(minute=30)
-        self.assertEqual(c.minute, set([30]))
-        c = crontab(minute='30')
-        self.assertEqual(c.minute, set([30]))
-        c = crontab(minute=(30, 40, 50))
-        self.assertEqual(c.minute, set([30, 40, 50]))
-        c = crontab(minute=set([30, 40, 50]))
-        self.assertEqual(c.minute, set([30, 40, 50]))
-
-    def test_crontab_spec_invalid_minute(self):
-        with self.assertRaises(ValueError):
-            crontab(minute=60)
-        with self.assertRaises(ValueError):
-            crontab(minute='0-100')
-
-    def test_crontab_spec_hour_formats(self):
-        c = crontab(hour=6)
-        self.assertEqual(c.hour, set([6]))
-        c = crontab(hour='5')
-        self.assertEqual(c.hour, set([5]))
-        c = crontab(hour=(4, 8, 12))
-        self.assertEqual(c.hour, set([4, 8, 12]))
-
-    def test_crontab_spec_invalid_hour(self):
-        with self.assertRaises(ValueError):
-            crontab(hour=24)
-        with self.assertRaises(ValueError):
-            crontab(hour='0-30')
-
-    def test_crontab_spec_dow_formats(self):
-        c = crontab(day_of_week=5)
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='5')
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='fri')
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='tuesday,sunday,fri')
-        self.assertEqual(c.day_of_week, set([0, 2, 5]))
-        c = crontab(day_of_week='mon-fri')
-        self.assertEqual(c.day_of_week, set([1, 2, 3, 4, 5]))
-        c = crontab(day_of_week='*/2')
-        self.assertEqual(c.day_of_week, set([0, 2, 4, 6]))
-
-    def test_crontab_spec_invalid_dow(self):
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='fooday-barday')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='1,4,foo')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='7')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='12')
-
-    def test_crontab_spec_dom_formats(self):
-        c = crontab(day_of_month=5)
-        self.assertEqual(c.day_of_month, set([5]))
-        c = crontab(day_of_month='5')
-        self.assertEqual(c.day_of_month, set([5]))
-        c = crontab(day_of_month='2,4,6')
-        self.assertEqual(c.day_of_month, set([2, 4, 6]))
-        c = crontab(day_of_month='*/5')
-        self.assertEqual(c.day_of_month, set([1, 6, 11, 16, 21, 26, 31]))
-
-    def test_crontab_spec_invalid_dom(self):
-        with self.assertRaises(ValueError):
-            crontab(day_of_month=0)
-        with self.assertRaises(ValueError):
-            crontab(day_of_month='0-10')
-        with self.assertRaises(ValueError):
-            crontab(day_of_month=32)
-        with self.assertRaises(ValueError):
-            crontab(day_of_month='31,32')
-
-    def test_crontab_spec_moy_formats(self):
-        c = crontab(month_of_year=1)
-        self.assertEqual(c.month_of_year, set([1]))
-        c = crontab(month_of_year='1')
-        self.assertEqual(c.month_of_year, set([1]))
-        c = crontab(month_of_year='2,4,6')
-        self.assertEqual(c.month_of_year, set([2, 4, 6]))
-        c = crontab(month_of_year='*/2')
-        self.assertEqual(c.month_of_year, set([1, 3, 5, 7, 9, 11]))
-        c = crontab(month_of_year='2-12/2')
-        self.assertEqual(c.month_of_year, set([2, 4, 6, 8, 10, 12]))
-
-    def test_crontab_spec_invalid_moy(self):
-        with self.assertRaises(ValueError):
-            crontab(month_of_year=0)
-        with self.assertRaises(ValueError):
-            crontab(month_of_year='0-5')
-        with self.assertRaises(ValueError):
-            crontab(month_of_year=13)
-        with self.assertRaises(ValueError):
-            crontab(month_of_year='12,13')
-
-    def seconds_almost_equal(self, a, b, precision):
-        for index, skew in enumerate((+0.1, 0, -0.1)):
-            try:
-                self.assertAlmostEqual(a, b + skew, precision)
-            except AssertionError:
-                if index + 1 >= 3:
-                    raise
-            else:
-                break
-
-    def assertRelativedelta(self, due, last_ran):
-        try:
-            from dateutil.relativedelta import relativedelta
-        except ImportError:
-            return
-        l1, d1, n1 = due.run_every.remaining_delta(last_ran)
-        l2, d2, n2 = due.run_every.remaining_delta(last_ran,
-                                                   ffwd=relativedelta)
-        if not isinstance(d1, relativedelta):
-            self.assertEqual(l1, l2)
-            for field, value in items(d1._fields()):
-                self.assertEqual(getattr(d1, field), value)
-            self.assertFalse(d2.years)
-            self.assertFalse(d2.months)
-            self.assertFalse(d2.days)
-            self.assertFalse(d2.leapdays)
-            self.assertFalse(d2.hours)
-            self.assertFalse(d2.minutes)
-            self.assertFalse(d2.seconds)
-            self.assertFalse(d2.microseconds)
-
-    def test_every_minute_execution_is_due(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertRelativedelta(every_minute, last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    def test_every_minute_execution_is_not_due(self):
-        last_ran = self.now - timedelta(seconds=self.now.second)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertFalse(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 29th of May 2010 is a saturday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 29, 10, 30))
-    def test_execution_is_due_on_saturday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 30th of May 2010 is a sunday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 30, 10, 30))
-    def test_execution_is_due_on_sunday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 31st of May 2010 is a monday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 31, 10, 30))
-    def test_execution_is_due_on_monday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 30))
-    def test_every_hour_execution_is_due(self):
-        due, remaining = hourly.run_every.is_due(
-            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 60 * 60)
-
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 29))
-    def test_every_hour_execution_is_not_due(self):
-        due, remaining = hourly.run_every.is_due(
-            datetime(2010, 5, 10, 9, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 15))
-    def test_first_quarter_execution_is_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 30))
-    def test_second_quarter_execution_is_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 14))
-    def test_first_quarter_execution_is_not_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 10, 0))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 29))
-    def test_second_quarter_execution_is_not_due(self):
-        due, remaining = quarterly.run_every.is_due(
-            datetime(2010, 5, 10, 10, 15))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 7, 30))
-    def test_daily_execution_is_due(self):
-        due, remaining = daily.run_every.is_due(
-            datetime(2010, 5, 9, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 24 * 60 * 60)
-
-    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 10, 30))
-    def test_daily_execution_is_not_due(self):
-        due, remaining = daily.run_every.is_due(
-            datetime(2010, 5, 10, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 21 * 60 * 60)
-
-    @patch_crontab_nowfun(weekly, datetime(2010, 5, 6, 7, 30))
-    def test_weekly_execution_is_due(self):
-        due, remaining = weekly.run_every.is_due(
-            datetime(2010, 4, 30, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 7 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(weekly, datetime(2010, 5, 7, 10, 30))
-    def test_weekly_execution_is_not_due(self):
-        due, remaining = weekly.run_every.is_due(
-            datetime(2010, 5, 6, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 6 * 24 * 60 * 60 - 3 * 60 * 60)
-
-    @patch_crontab_nowfun(monthly, datetime(2010, 5, 13, 7, 30))
-    def test_monthly_execution_is_due(self):
-        due, remaining = monthly.run_every.is_due(
-            datetime(2010, 4, 8, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 28 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(monthly, datetime(2010, 5, 9, 10, 30))
-    def test_monthly_execution_is_not_due(self):
-        due, remaining = monthly.run_every.is_due(
-            datetime(2010, 4, 8, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
-    def test_monthly_moy_execution_is_due(self):
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 7, 4, 10, 0))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 60.)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2013, 6, 28, 14, 30))
-    def test_monthly_moy_execution_is_not_due(self):
-        raise SkipTest('unstable test')
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 6, 28, 22, 14))
-        self.assertFalse(due)
-        attempt = (
-            time.mktime(datetime(2014, 2, 26, 22, 0).timetuple()) -
-            time.mktime(datetime(2013, 6, 28, 14, 30).timetuple()) -
-            60 * 60
-        )
-        self.assertEqual(remaining, attempt)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
-    def test_monthly_moy_execution_is_due2(self):
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 2, 28, 10, 0))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 60.)
-
-    @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 21, 0))
-    def test_monthly_moy_execution_is_not_due2(self):
-        due, remaining = monthly_moy.run_every.is_due(
-            datetime(2013, 6, 28, 22, 14))
-        self.assertFalse(due)
-        attempt = 60 * 60
-        self.assertEqual(remaining, attempt)
-
-    @patch_crontab_nowfun(yearly, datetime(2010, 3, 11, 7, 30))
-    def test_yearly_execution_is_due(self):
-        due, remaining = yearly.run_every.is_due(
-            datetime(2009, 3, 12, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 364 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(yearly, datetime(2010, 3, 7, 10, 30))
-    def test_yearly_execution_is_not_due(self):
-        due, remaining = yearly.run_every.is_due(
-            datetime(2009, 3, 12, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60)

+ 29 - 26
celery/tests/tasks/test_trace.py

@@ -16,60 +16,63 @@ from celery.app.trace import (
 from celery.tests.case import AppCase
 from celery.tests.case import AppCase
 
 
 
 
-def trace(task, args=(), kwargs={}, propagate=False, **opts):
+def trace(app, task, args=(), kwargs={}, propagate=False, **opts):
     return eager_trace_task(task, 'id-1', args, kwargs,
     return eager_trace_task(task, 'id-1', args, kwargs,
-                            propagate=propagate, **opts)
+                            propagate=propagate, app=app, **opts)
 
 
 
 
 class TraceCase(AppCase):
 class TraceCase(AppCase):
 
 
     def setup(self):
     def setup(self):
-        @self.app.task
+        @self.app.task(shared=False)
         def add(x, y):
         def add(x, y):
             return x + y
             return x + y
         self.add = add
         self.add = add
 
 
-        @self.app.task(ignore_result=True)
+        @self.app.task(shared=False, ignore_result=True)
         def add_cast(x, y):
         def add_cast(x, y):
             return x + y
             return x + y
         self.add_cast = add_cast
         self.add_cast = add_cast
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def raises(exc):
         def raises(exc):
             raise exc
             raise exc
         self.raises = raises
         self.raises = raises
 
 
+    def trace(self, *args, **kwargs):
+        return trace(self.app, *args, **kwargs)
+
 
 
 class test_trace(TraceCase):
 class test_trace(TraceCase):
 
 
     def test_trace_successful(self):
     def test_trace_successful(self):
-        retval, info = trace(self.add, (2, 2), {})
+        retval, info = self.trace(self.add, (2, 2), {})
         self.assertIsNone(info)
         self.assertIsNone(info)
         self.assertEqual(retval, 4)
         self.assertEqual(retval, 4)
 
 
     def test_trace_on_success(self):
     def test_trace_on_success(self):
 
 
-        @self.app.task(on_success=Mock())
+        @self.app.task(shared=False, on_success=Mock())
         def add_with_success(x, y):
         def add_with_success(x, y):
             return x + y
             return x + y
 
 
-        trace(add_with_success, (2, 2), {})
+        self.trace(add_with_success, (2, 2), {})
         self.assertTrue(add_with_success.on_success.called)
         self.assertTrue(add_with_success.on_success.called)
 
 
     def test_trace_after_return(self):
     def test_trace_after_return(self):
 
 
-        @self.app.task(after_return=Mock())
+        @self.app.task(shared=False, after_return=Mock())
         def add_with_after_return(x, y):
         def add_with_after_return(x, y):
             return x + y
             return x + y
 
 
-        trace(add_with_after_return, (2, 2), {})
+        self.trace(add_with_after_return, (2, 2), {})
         self.assertTrue(add_with_after_return.after_return.called)
         self.assertTrue(add_with_after_return.after_return.called)
 
 
     def test_with_prerun_receivers(self):
     def test_with_prerun_receivers(self):
         on_prerun = Mock()
         on_prerun = Mock()
         signals.task_prerun.connect(on_prerun)
         signals.task_prerun.connect(on_prerun)
         try:
         try:
-            trace(self.add, (2, 2), {})
+            self.trace(self.add, (2, 2), {})
             self.assertTrue(on_prerun.called)
             self.assertTrue(on_prerun.called)
         finally:
         finally:
             signals.task_prerun.receivers[:] = []
             signals.task_prerun.receivers[:] = []
@@ -78,7 +81,7 @@ class test_trace(TraceCase):
         on_postrun = Mock()
         on_postrun = Mock()
         signals.task_postrun.connect(on_postrun)
         signals.task_postrun.connect(on_postrun)
         try:
         try:
-            trace(self.add, (2, 2), {})
+            self.trace(self.add, (2, 2), {})
             self.assertTrue(on_postrun.called)
             self.assertTrue(on_postrun.called)
         finally:
         finally:
             signals.task_postrun.receivers[:] = []
             signals.task_postrun.receivers[:] = []
@@ -87,62 +90,62 @@ class test_trace(TraceCase):
         on_success = Mock()
         on_success = Mock()
         signals.task_success.connect(on_success)
         signals.task_success.connect(on_success)
         try:
         try:
-            trace(self.add, (2, 2), {})
+            self.trace(self.add, (2, 2), {})
             self.assertTrue(on_success.called)
             self.assertTrue(on_success.called)
         finally:
         finally:
             signals.task_success.receivers[:] = []
             signals.task_success.receivers[:] = []
 
 
     def test_when_chord_part(self):
     def test_when_chord_part(self):
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def add(x, y):
         def add(x, y):
             return x + y
             return x + y
         add.backend = Mock()
         add.backend = Mock()
 
 
-        trace(add, (2, 2), {}, request={'chord': uuid()})
+        self.trace(add, (2, 2), {}, request={'chord': uuid()})
         add.backend.on_chord_part_return.assert_called_with(add)
         add.backend.on_chord_part_return.assert_called_with(add)
 
 
     def test_when_backend_cleanup_raises(self):
     def test_when_backend_cleanup_raises(self):
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def add(x, y):
         def add(x, y):
             return x + y
             return x + y
         add.backend = Mock(name='backend')
         add.backend = Mock(name='backend')
         add.backend.process_cleanup.side_effect = KeyError()
         add.backend.process_cleanup.side_effect = KeyError()
-        trace(add, (2, 2), {}, eager=False)
+        self.trace(add, (2, 2), {}, eager=False)
         add.backend.process_cleanup.assert_called_with()
         add.backend.process_cleanup.assert_called_with()
         add.backend.process_cleanup.side_effect = MemoryError()
         add.backend.process_cleanup.side_effect = MemoryError()
         with self.assertRaises(MemoryError):
         with self.assertRaises(MemoryError):
-            trace(add, (2, 2), {}, eager=False)
+            self.trace(add, (2, 2), {}, eager=False)
 
 
     def test_when_Ignore(self):
     def test_when_Ignore(self):
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def ignored():
         def ignored():
             raise Ignore()
             raise Ignore()
 
 
-        retval, info = trace(ignored, (), {})
+        retval, info = self.trace(ignored, (), {})
         self.assertEqual(info.state, states.IGNORED)
         self.assertEqual(info.state, states.IGNORED)
 
 
     def test_trace_SystemExit(self):
     def test_trace_SystemExit(self):
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
-            trace(self.raises, (SystemExit(), ), {})
+            self.trace(self.raises, (SystemExit(), ), {})
 
 
     def test_trace_RetryTaskError(self):
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError('foo', 'bar')
         exc = RetryTaskError('foo', 'bar')
-        _, info = trace(self.raises, (exc, ), {})
+        _, info = self.trace(self.raises, (exc, ), {})
         self.assertEqual(info.state, states.RETRY)
         self.assertEqual(info.state, states.RETRY)
         self.assertIs(info.retval, exc)
         self.assertIs(info.retval, exc)
 
 
     def test_trace_exception(self):
     def test_trace_exception(self):
         exc = KeyError('foo')
         exc = KeyError('foo')
-        _, info = trace(self.raises, (exc, ), {})
+        _, info = self.trace(self.raises, (exc, ), {})
         self.assertEqual(info.state, states.FAILURE)
         self.assertEqual(info.state, states.FAILURE)
         self.assertIs(info.retval, exc)
         self.assertIs(info.retval, exc)
 
 
     def test_trace_exception_propagate(self):
     def test_trace_exception_propagate(self):
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
-            trace(self.raises, (KeyError('foo'), ), {}, propagate=True)
+            self.trace(self.raises, (KeyError('foo'), ), {}, propagate=True)
 
 
     @patch('celery.app.trace.build_tracer')
     @patch('celery.app.trace.build_tracer')
     @patch('celery.app.trace.report_internal_error')
     @patch('celery.app.trace.report_internal_error')
@@ -151,7 +154,7 @@ class test_trace(TraceCase):
         tracer.side_effect = KeyError('foo')
         tracer.side_effect = KeyError('foo')
         build_tracer.return_value = tracer
         build_tracer.return_value = tracer
 
 
-        @self.app.task
+        @self.app.task(shared=False)
         def xtask():
         def xtask():
             pass
             pass
 
 
@@ -180,7 +183,7 @@ class test_stackprotection(AppCase):
     def test_stackprotection(self):
     def test_stackprotection(self):
         setup_worker_optimizations(self.app)
         setup_worker_optimizations(self.app)
         try:
         try:
-            @self.app.task(bind=True)
+            @self.app.task(shared=False, bind=True)
             def foo(self, i):
             def foo(self, i):
                 if i:
                 if i:
                     return foo(0)
                     return foo(0)

+ 7 - 8
celery/tests/utils/test_dispatcher.py

@@ -55,7 +55,7 @@ class DispatcherTests(Case):
         # force cleanup just in case
         # force cleanup just in case
         signal.receivers = []
         signal.receivers = []
 
 
-    def testExact(self):
+    def test_exact(self):
         a_signal.connect(receiver_1_arg, sender=self)
         a_signal.connect(receiver_1_arg, sender=self)
         expected = [(receiver_1_arg, 'test')]
         expected = [(receiver_1_arg, 'test')]
         result = a_signal.send(sender=self, val='test')
         result = a_signal.send(sender=self, val='test')
@@ -63,7 +63,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(receiver_1_arg, sender=self)
         a_signal.disconnect(receiver_1_arg, sender=self)
         self._testIsClean(a_signal)
         self._testIsClean(a_signal)
 
 
-    def testIgnoredSender(self):
+    def test_ignored_sender(self):
         a_signal.connect(receiver_1_arg)
         a_signal.connect(receiver_1_arg)
         expected = [(receiver_1_arg, 'test')]
         expected = [(receiver_1_arg, 'test')]
         result = a_signal.send(sender=self, val='test')
         result = a_signal.send(sender=self, val='test')
@@ -71,7 +71,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(receiver_1_arg)
         a_signal.disconnect(receiver_1_arg)
         self._testIsClean(a_signal)
         self._testIsClean(a_signal)
 
 
-    def testGarbageCollected(self):
+    def test_garbage_collected(self):
         a = Callable()
         a = Callable()
         a_signal.connect(a.a, sender=self)
         a_signal.connect(a.a, sender=self)
         expected = []
         expected = []
@@ -81,7 +81,7 @@ class DispatcherTests(Case):
         self.assertEqual(result, expected)
         self.assertEqual(result, expected)
         self._testIsClean(a_signal)
         self._testIsClean(a_signal)
 
 
-    def testMultipleRegistration(self):
+    def test_multiple_registration(self):
         a = Callable()
         a = Callable()
         a_signal.connect(a)
         a_signal.connect(a)
         a_signal.connect(a)
         a_signal.connect(a)
@@ -97,7 +97,7 @@ class DispatcherTests(Case):
         garbage_collect()
         garbage_collect()
         self._testIsClean(a_signal)
         self._testIsClean(a_signal)
 
 
-    def testUidRegistration(self):
+    def test_uid_registration(self):
 
 
         def uid_based_receiver_1(**kwargs):
         def uid_based_receiver_1(**kwargs):
             pass
             pass
@@ -111,8 +111,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(dispatch_uid='uid')
         a_signal.disconnect(dispatch_uid='uid')
         self._testIsClean(a_signal)
         self._testIsClean(a_signal)
 
 
-    def testRobust(self):
-        """Test the sendRobust function"""
+    def test_robust(self):
 
 
         def fails(val, **kwargs):
         def fails(val, **kwargs):
             raise ValueError('this')
             raise ValueError('this')
@@ -125,7 +124,7 @@ class DispatcherTests(Case):
         a_signal.disconnect(fails)
         a_signal.disconnect(fails)
         self._testIsClean(a_signal)
         self._testIsClean(a_signal)
 
 
-    def testDisconnection(self):
+    def test_disconnection(self):
         receiver_1 = Callable()
         receiver_1 = Callable()
         receiver_2 = Callable()
         receiver_2 = Callable()
         receiver_3 = Callable()
         receiver_3 = Callable()

+ 22 - 8
celery/tests/utils/test_saferef.py

@@ -46,18 +46,30 @@ class SaferefTests(Case):
         del self.ts
         del self.ts
         del self.ss
         del self.ss
 
 
-    def testIn(self):
-        """Test the "in" operator for safe references (cmp)"""
+    def test_in(self):
+        """test_in
+
+        Test the "in" operator for safe references (cmp)
+
+        """
         for t in self.ts[:50]:
         for t in self.ts[:50]:
             self.assertTrue(safe_ref(t.x) in self.ss)
             self.assertTrue(safe_ref(t.x) in self.ss)
 
 
-    def testValid(self):
-        """Test that the references are valid (return instance methods)"""
+    def test_valid(self):
+        """test_value
+
+        Test that the references are valid (return instance methods)
+
+        """
         for s in self.ss:
         for s in self.ss:
             self.assertTrue(s())
             self.assertTrue(s())
 
 
-    def testShortCircuit(self):
-        """Test that creation short-circuits to reuse existing references"""
+    def test_shortcircuit(self):
+        """test_shortcircuit
+
+        Test that creation short-circuits to reuse existing references
+
+        """
         sd = {}
         sd = {}
         for s in self.ss:
         for s in self.ss:
             sd[s] = 1
             sd[s] = 1
@@ -67,8 +79,10 @@ class SaferefTests(Case):
             else:
             else:
                 self.assertIn(safe_ref(t), sd)
                 self.assertIn(safe_ref(t), sd)
 
 
-    def testRepresentation(self):
-        """Test that the reference object's representation works
+    def test_representation(self):
+        """test_representation
+
+        Test that the reference object's representation works
 
 
         XXX Doesn't currently check the results, just that no error
         XXX Doesn't currently check the results, just that no error
             is raised
             is raised

+ 4 - 6
celery/tests/utils/test_text.py

@@ -1,6 +1,5 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from celery import Celery
 from celery.utils.text import (
 from celery.utils.text import (
     indent,
     indent,
     ensure_2lines,
     ensure_2lines,
@@ -9,7 +8,7 @@ from celery.utils.text import (
     abbrtask,
     abbrtask,
     pretty,
     pretty,
 )
 )
-from celery.tests.case import Case
+from celery.tests.case import AppCase, Case
 
 
 RANDTEXT = """\
 RANDTEXT = """\
 The quick brown
 The quick brown
@@ -43,15 +42,14 @@ QUEUE_FORMAT1 = '.> queue1           exchange=exchange1(type1) key=bind1'
 QUEUE_FORMAT2 = '.> queue2           exchange=exchange2(type2) key=bind2'
 QUEUE_FORMAT2 = '.> queue2           exchange=exchange2(type2) key=bind2'
 
 
 
 
-class test_Info(Case):
+class test_Info(AppCase):
 
 
     def test_textindent(self):
     def test_textindent(self):
         self.assertEqual(indent(RANDTEXT, 4), RANDTEXT_RES)
         self.assertEqual(indent(RANDTEXT, 4), RANDTEXT_RES)
 
 
     def test_format_queues(self):
     def test_format_queues(self):
-        celery = Celery(set_as_current=False)
-        celery.amqp.queues = celery.amqp.Queues(QUEUES)
-        self.assertEqual(sorted(celery.amqp.queues.format().split('\n')),
+        self.app.amqp.queues = self.app.amqp.Queues(QUEUES)
+        self.assertEqual(sorted(self.app.amqp.queues.format().split('\n')),
                          sorted([QUEUE_FORMAT1, QUEUE_FORMAT2]))
                          sorted([QUEUE_FORMAT1, QUEUE_FORMAT2]))
 
 
     def test_ensure_2lines(self):
     def test_ensure_2lines(self):

+ 6 - 6
celery/tests/worker/test_bootsteps.py

@@ -4,10 +4,10 @@ from mock import Mock, patch
 
 
 from celery import bootsteps
 from celery import bootsteps
 
 
-from celery.tests.case import AppCase, Case
+from celery.tests.case import AppCase
 
 
 
 
-class test_StepFormatter(Case):
+class test_StepFormatter(AppCase):
 
 
     def test_get_prefix(self):
     def test_get_prefix(self):
         f = bootsteps.StepFormatter()
         f = bootsteps.StepFormatter()
@@ -53,12 +53,12 @@ class test_StepFormatter(Case):
         })
         })
 
 
 
 
-class test_Step(Case):
+class test_Step(AppCase):
 
 
     class Def(bootsteps.StartStopStep):
     class Def(bootsteps.StartStopStep):
         name = 'test_Step.Def'
         name = 'test_Step.Def'
 
 
-    def setUp(self):
+    def setup(self):
         self.steps = []
         self.steps = []
 
 
     def test_blueprint_name(self, bp='test_blueprint_name'):
     def test_blueprint_name(self, bp='test_blueprint_name'):
@@ -151,12 +151,12 @@ class test_ConsumerStep(AppCase):
         step.start(self)
         step.start(self)
 
 
 
 
-class test_StartStopStep(Case):
+class test_StartStopStep(AppCase):
 
 
     class Def(bootsteps.StartStopStep):
     class Def(bootsteps.StartStopStep):
         name = 'test_StartStopStep.Def'
         name = 'test_StartStopStep.Def'
 
 
-    def setUp(self):
+    def setup(self):
         self.steps = []
         self.steps = []
 
 
     def test_start__stop(self):
     def test_start__stop(self):

+ 14 - 27
celery/tests/worker/test_consumer.py

@@ -59,26 +59,16 @@ class test_Consumer(AppCase):
     def test_sets_heartbeat(self):
     def test_sets_heartbeat(self):
         c = self.get_consumer(amqheartbeat=10)
         c = self.get_consumer(amqheartbeat=10)
         self.assertEqual(c.amqheartbeat, 10)
         self.assertEqual(c.amqheartbeat, 10)
-        prev, self.app.conf.BROKER_HEARTBEAT = (
-            self.app.conf.BROKER_HEARTBEAT, 20,
-        )
-        try:
-            c = self.get_consumer(amqheartbeat=None)
-            self.assertEqual(c.amqheartbeat, 20)
-        finally:
-            self.app.conf.BROKER_HEARTBEAT = prev
+        self.app.conf.BROKER_HEARTBEAT = 20
+        c = self.get_consumer(amqheartbeat=None)
+        self.assertEqual(c.amqheartbeat, 20)
 
 
     def test_gevent_bug_disables_connection_timeout(self):
     def test_gevent_bug_disables_connection_timeout(self):
         with patch('celery.worker.consumer._detect_environment') as de:
         with patch('celery.worker.consumer._detect_environment') as de:
             de.return_value = 'gevent'
             de.return_value = 'gevent'
-            prev, self.app.conf.BROKER_CONNECTION_TIMEOUT = (
-                self.app.conf.BROKER_CONNECTION_TIMEOUT, 33.33,
-            )
-            try:
-                self.get_consumer()
-                self.assertIsNone(self.app.conf.BROKER_CONNECTION_TIMEOUT)
-            finally:
-                self.app.conf.BROKER_CONNECTION_TIMEOUT = prev
+            self.app.conf.BROKER_CONNECTION_TIMEOUT = 33.33
+            self.get_consumer()
+            self.assertIsNone(self.app.conf.BROKER_CONNECTION_TIMEOUT)
 
 
     def test_limit_task(self):
     def test_limit_task(self):
         c = self.get_consumer()
         c = self.get_consumer()
@@ -163,17 +153,14 @@ class test_Consumer(AppCase):
             c.on_close()
             c.on_close()
 
 
     def test_connect_error_handler(self):
     def test_connect_error_handler(self):
-        _prev, self.app.connection = self.app.connection, Mock()
-        try:
-            conn = self.app.connection.return_value = Mock()
-            c = self.get_consumer()
-            self.assertTrue(c.connect())
-            self.assertTrue(conn.ensure_connection.called)
-            errback = conn.ensure_connection.call_args[0][0]
-            conn.alt = [(1, 2, 3)]
-            errback(Mock(), 0)
-        finally:
-            self.app.connection = _prev
+        self.app.connection = Mock()
+        conn = self.app.connection.return_value = Mock()
+        c = self.get_consumer()
+        self.assertTrue(c.connect())
+        self.assertTrue(conn.ensure_connection.called)
+        errback = conn.ensure_connection.call_args[0][0]
+        conn.alt = [(1, 2, 3)]
+        errback(Mock(), 0)
 
 
 
 
 class test_Heart(AppCase):
 class test_Heart(AppCase):

+ 59 - 76
celery/tests/worker/test_control.py

@@ -198,28 +198,24 @@ class test_ControlPanel(AppCase):
 
 
     def test_time_limit(self):
     def test_time_limit(self):
         panel = self.create_panel(consumer=Mock())
         panel = self.create_panel(consumer=Mock())
-        th, ts = self.mytask.time_limit, self.mytask.soft_time_limit
-        try:
-            r = panel.handle('time_limit', arguments=dict(
-                task_name=self.mytask.name, hard=30, soft=10))
-            self.assertEqual(
-                (self.mytask.time_limit, self.mytask.soft_time_limit),
-                (30, 10),
-            )
-            self.assertIn('ok', r)
-            r = panel.handle('time_limit', arguments=dict(
-                task_name=self.mytask.name, hard=None, soft=None))
-            self.assertEqual(
-                (self.mytask.time_limit, self.mytask.soft_time_limit),
-                (None, None),
-            )
-            self.assertIn('ok', r)
+        r = panel.handle('time_limit', arguments=dict(
+            task_name=self.mytask.name, hard=30, soft=10))
+        self.assertEqual(
+            (self.mytask.time_limit, self.mytask.soft_time_limit),
+            (30, 10),
+        )
+        self.assertIn('ok', r)
+        r = panel.handle('time_limit', arguments=dict(
+            task_name=self.mytask.name, hard=None, soft=None))
+        self.assertEqual(
+            (self.mytask.time_limit, self.mytask.soft_time_limit),
+            (None, None),
+        )
+        self.assertIn('ok', r)
 
 
-            r = panel.handle('time_limit', arguments=dict(
-                task_name='248e8afya9s8dh921eh928', hard=30))
-            self.assertIn('error', r)
-        finally:
-            self.time_limit, self.soft_time_limit = th, ts
+        r = panel.handle('time_limit', arguments=dict(
+            task_name='248e8afya9s8dh921eh928', hard=30))
+        self.assertIn('error', r)
 
 
     def test_active_queues(self):
     def test_active_queues(self):
         import kombu
         import kombu
@@ -381,19 +377,15 @@ class test_ControlPanel(AppCase):
         panel = self.create_panel(app=self.app, consumer=consumer)
         panel = self.create_panel(app=self.app, consumer=consumer)
 
 
         task = self.app.tasks[self.mytask.name]
         task = self.app.tasks[self.mytask.name]
-        old_rate_limit = task.rate_limit
-        try:
-            panel.handle('rate_limit', arguments=dict(task_name=task.name,
-                                                      rate_limit='100/m'))
-            self.assertEqual(task.rate_limit, '100/m')
-            self.assertTrue(consumer.reset)
-            consumer.reset = False
-            panel.handle('rate_limit', arguments=dict(task_name=task.name,
-                                                      rate_limit=0))
-            self.assertEqual(task.rate_limit, 0)
-            self.assertTrue(consumer.reset)
-        finally:
-            task.rate_limit = old_rate_limit
+        panel.handle('rate_limit', arguments=dict(task_name=task.name,
+                                                  rate_limit='100/m'))
+        self.assertEqual(task.rate_limit, '100/m')
+        self.assertTrue(consumer.reset)
+        consumer.reset = False
+        panel.handle('rate_limit', arguments=dict(task_name=task.name,
+                                                  rate_limit=0))
+        self.assertEqual(task.rate_limit, 0)
+        self.assertTrue(consumer.reset)
 
 
     def test_rate_limit_nonexistant_task(self):
     def test_rate_limit_nonexistant_task(self):
         self.panel.handle('rate_limit', arguments={
         self.panel.handle('rate_limit', arguments={
@@ -509,13 +501,10 @@ class test_ControlPanel(AppCase):
             panel.handle('pool_restart', {'reloader': _reload})
             panel.handle('pool_restart', {'reloader': _reload})
 
 
         self.app.conf.CELERYD_POOL_RESTARTS = True
         self.app.conf.CELERYD_POOL_RESTARTS = True
-        try:
-            panel.handle('pool_restart', {'reloader': _reload})
-            self.assertTrue(consumer.controller.pool.restart.called)
-            self.assertFalse(_reload.called)
-            self.assertFalse(_import.called)
-        finally:
-            self.app.conf.CELERYD_POOL_RESTARTS = False
+        panel.handle('pool_restart', {'reloader': _reload})
+        self.assertTrue(consumer.controller.pool.restart.called)
+        self.assertFalse(_reload.called)
+        self.assertFalse(_import.called)
 
 
     def test_pool_restart_import_modules(self):
     def test_pool_restart_import_modules(self):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
@@ -527,18 +516,15 @@ class test_ControlPanel(AppCase):
         _reload = Mock()
         _reload = Mock()
 
 
         self.app.conf.CELERYD_POOL_RESTARTS = True
         self.app.conf.CELERYD_POOL_RESTARTS = True
-        try:
-            panel.handle('pool_restart', {'modules': ['foo', 'bar'],
-                                          'reloader': _reload})
-
-            self.assertTrue(consumer.controller.pool.restart.called)
-            self.assertFalse(_reload.called)
-            self.assertItemsEqual(
-                [call('bar'), call('foo')],
-                _import.call_args_list,
-            )
-        finally:
-            self.app.conf.CELERYD_POOL_RESTARTS = False
+        panel.handle('pool_restart', {'modules': ['foo', 'bar'],
+                                      'reloader': _reload})
+
+        self.assertTrue(consumer.controller.pool.restart.called)
+        self.assertFalse(_reload.called)
+        self.assertItemsEqual(
+            [call('bar'), call('foo')],
+            _import.call_args_list,
+        )
 
 
     def test_pool_restart_reload_modules(self):
     def test_pool_restart_reload_modules(self):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
@@ -550,26 +536,23 @@ class test_ControlPanel(AppCase):
         _reload = Mock()
         _reload = Mock()
 
 
         self.app.conf.CELERYD_POOL_RESTARTS = True
         self.app.conf.CELERYD_POOL_RESTARTS = True
-        try:
-            with patch.dict(sys.modules, {'foo': None}):
-                panel.handle('pool_restart', {'modules': ['foo'],
-                                              'reload': False,
-                                              'reloader': _reload})
-
-                self.assertTrue(consumer.controller.pool.restart.called)
-                self.assertFalse(_reload.called)
-                self.assertFalse(_import.called)
-
-                _import.reset_mock()
-                _reload.reset_mock()
-                consumer.controller.pool.restart.reset_mock()
-
-                panel.handle('pool_restart', {'modules': ['foo'],
-                                              'reload': True,
-                                              'reloader': _reload})
-
-                self.assertTrue(consumer.controller.pool.restart.called)
-                self.assertTrue(_reload.called)
-                self.assertFalse(_import.called)
-        finally:
-            self.app.conf.CELERYD_POOL_RESTARTS = False
+        with patch.dict(sys.modules, {'foo': None}):
+            panel.handle('pool_restart', {'modules': ['foo'],
+                                          'reload': False,
+                                          'reloader': _reload})
+
+            self.assertTrue(consumer.controller.pool.restart.called)
+            self.assertFalse(_reload.called)
+            self.assertFalse(_import.called)
+
+            _import.reset_mock()
+            _reload.reset_mock()
+            consumer.controller.pool.restart.reset_mock()
+
+            panel.handle('pool_restart', {'modules': ['foo'],
+                                          'reload': True,
+                                          'reloader': _reload})
+
+            self.assertTrue(consumer.controller.pool.restart.called)
+            self.assertTrue(_reload.called)
+            self.assertFalse(_import.called)

+ 2 - 2
celery/tests/worker/test_heartbeat.py

@@ -1,7 +1,7 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from celery.worker.heartbeat import Heart
 from celery.worker.heartbeat import Heart
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
 class MockDispatcher(object):
 class MockDispatcher(object):
@@ -45,7 +45,7 @@ class MockTimer(object):
         entry.cancel()
         entry.cancel()
 
 
 
 
-class test_Heart(Case):
+class test_Heart(AppCase):
 
 
     def test_start_stop(self):
     def test_start_stop(self):
         timer = MockTimer()
         timer = MockTimer()

+ 1 - 1
celery/tests/worker/test_loops.py

@@ -100,7 +100,7 @@ class test_asynloop(AppCase):
 
 
     def setup(self):
     def setup(self):
 
 
-        @self.app.task()
+        @self.app.task(shared=False)
         def add(x, y):
         def add(x, y):
             return x + y
             return x + y
         self.add = add
         self.add = add

+ 27 - 30
celery/tests/worker/test_request.py

@@ -84,7 +84,7 @@ def jail(app, task_id, name, args, kwargs):
     task = app.tasks[name]
     task = app.tasks[name]
     task.__trace__ = None  # rebuild
     task.__trace__ = None  # rebuild
     return trace_task(
     return trace_task(
-        task, task_id, args, kwargs, request=request, eager=False,
+        task, task_id, args, kwargs, request=request, eager=False, app=app,
     )
     )
 
 
 
 
@@ -161,25 +161,23 @@ class test_trace_task(AppCase):
         self.assertEqual(ret, 4)
         self.assertEqual(ret, 4)
 
 
     def test_marked_as_started(self):
     def test_marked_as_started(self):
+        _started = []
 
 
-        class Backend(self.mytask.backend.__class__):
-            _started = []
-
-            def store_result(self, tid, meta, state):
-                if state == states.STARTED:
-                    self._started.append(tid)
-
-        self.mytask.backend = Backend(self.app)
+        def store_result(tid, meta, state):
+            if state == states.STARTED:
+                _started.append(tid)
+        self.mytask.backend.store_result = Mock(name='store_result')
+        self.mytask.backend.store_result.side_effect = store_result
         self.mytask.track_started = True
         self.mytask.track_started = True
 
 
         tid = uuid()
         tid = uuid()
         jail(self.app, tid, self.mytask.name, [2], {})
         jail(self.app, tid, self.mytask.name, [2], {})
-        self.assertIn(tid, Backend._started)
+        self.assertIn(tid, _started)
 
 
         self.mytask.ignore_result = True
         self.mytask.ignore_result = True
         tid = uuid()
         tid = uuid()
         jail(self.app, tid, self.mytask.name, [2], {})
         jail(self.app, tid, self.mytask.name, [2], {})
-        self.assertNotIn(tid, Backend._started)
+        self.assertNotIn(tid, _started)
 
 
     def test_execute_jail_failure(self):
     def test_execute_jail_failure(self):
         ret = jail(
         ret = jail(
@@ -309,18 +307,15 @@ class test_Request(AppCase):
         task.freeze()
         task.freeze()
         req = self.get_request(task)
         req = self.get_request(task)
         self.add.accept_magic_kwargs = True
         self.add.accept_magic_kwargs = True
-        try:
-            pool = Mock()
-            req.execute_using_pool(pool)
-            self.assertTrue(pool.apply_async.called)
-            args = pool.apply_async.call_args[1]['args']
-            self.assertEqual(args[0], task.task)
-            self.assertEqual(args[1], task.id)
-            self.assertEqual(args[2], task.args)
-            kwargs = args[3]
-            self.assertEqual(kwargs.get('task_name'), task.task)
-        finally:
-            self.add.accept_magic_kwargs = False
+        pool = Mock()
+        req.execute_using_pool(pool)
+        self.assertTrue(pool.apply_async.called)
+        args = pool.apply_async.call_args[1]['args']
+        self.assertEqual(args[0], task.task)
+        self.assertEqual(args[1], task.id)
+        self.assertEqual(args[2], task.args)
+        kwargs = args[3]
+        self.assertEqual(kwargs.get('task_name'), task.task)
 
 
     def test_task_wrapper_repr(self):
     def test_task_wrapper_repr(self):
         job = TaskRequest(
         job = TaskRequest(
@@ -697,6 +692,7 @@ class test_Request(AppCase):
         try:
         try:
             self.mytask.__trace__ = build_tracer(
             self.mytask.__trace__ = build_tracer(
                 self.mytask.name, self.mytask, self.app.loader, 'test',
                 self.mytask.name, self.mytask, self.app.loader, 'test',
+                app=self.app,
             )
             )
             res = trace.trace_task_ret(self.mytask.name, uuid(), [4], {})
             res = trace.trace_task_ret(self.mytask.name, uuid(), [4], {})
             self.assertEqual(res, 4 ** 4)
             self.assertEqual(res, 4 ** 4)
@@ -704,24 +700,25 @@ class test_Request(AppCase):
             reset_worker_optimizations()
             reset_worker_optimizations()
             self.assertIs(trace.trace_task_ret, trace._trace_task_ret)
             self.assertIs(trace.trace_task_ret, trace._trace_task_ret)
         delattr(self.mytask, '__trace__')
         delattr(self.mytask, '__trace__')
-        res = trace.trace_task_ret(self.mytask.name, uuid(), [4], {})
+        res = trace.trace_task_ret(
+            self.mytask.name, uuid(), [4], {}, app=self.app,
+        )
         self.assertEqual(res, 4 ** 4)
         self.assertEqual(res, 4 ** 4)
 
 
     def test_trace_task_ret(self):
     def test_trace_task_ret(self):
-        self.app.set_current()   # XXX compat test
         self.mytask.__trace__ = build_tracer(
         self.mytask.__trace__ = build_tracer(
             self.mytask.name, self.mytask, self.app.loader, 'test',
             self.mytask.name, self.mytask, self.app.loader, 'test',
+            app=self.app,
         )
         )
-        res = _trace_task_ret(self.mytask.name, uuid(), [4], {})
+        res = _trace_task_ret(self.mytask.name, uuid(), [4], {}, app=self.app)
         self.assertEqual(res, 4 ** 4)
         self.assertEqual(res, 4 ** 4)
 
 
     def test_trace_task_ret__no_trace(self):
     def test_trace_task_ret__no_trace(self):
-        self.app.set_current()  # XXX compat test
         try:
         try:
             delattr(self.mytask, '__trace__')
             delattr(self.mytask, '__trace__')
         except AttributeError:
         except AttributeError:
             pass
             pass
-        res = _trace_task_ret(self.mytask.name, uuid(), [4], {})
+        res = _trace_task_ret(self.mytask.name, uuid(), [4], {}, app=self.app)
         self.assertEqual(res, 4 ** 4)
         self.assertEqual(res, 4 ** 4)
 
 
     def test_trace_catches_exception(self):
     def test_trace_catches_exception(self):
@@ -735,7 +732,7 @@ class test_Request(AppCase):
 
 
         with self.assertWarnsRegex(RuntimeWarning,
         with self.assertWarnsRegex(RuntimeWarning,
                                    r'Exception raised outside'):
                                    r'Exception raised outside'):
-            res = trace_task(raising, uuid(), [], {})
+            res = trace_task(raising, uuid(), [], {}, app=self.app)
             self.assertIsInstance(res, ExceptionInfo)
             self.assertIsInstance(res, ExceptionInfo)
 
 
     def test_worker_task_trace_handle_retry(self):
     def test_worker_task_trace_handle_retry(self):
@@ -865,7 +862,7 @@ class test_Request(AppCase):
     def test_execute_success_some_kwargs(self):
     def test_execute_success_some_kwargs(self):
         scratch = {'task_id': None}
         scratch = {'task_id': None}
 
 
-        @self.app.task(accept_magic_kwargs=True)
+        @self.app.task(shared=False, accept_magic_kwargs=True)
         def mytask_some_kwargs(i, task_id):
         def mytask_some_kwargs(i, task_id):
             scratch['task_id'] = task_id
             scratch['task_id'] = task_id
             return i ** i
             return i ** i

+ 2 - 2
celery/tests/worker/test_revoke.py

@@ -1,10 +1,10 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from celery.worker import state
 from celery.worker import state
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class test_revoked(Case):
+class test_revoked(AppCase):
 
 
     def test_is_working(self):
     def test_is_working(self):
         state.revoked.add('foo')
         state.revoked.add('foo')

+ 8 - 15
celery/tests/worker/test_state.py

@@ -9,30 +9,22 @@ from celery.datastructures import LimitedSet
 from celery.exceptions import SystemTerminate
 from celery.exceptions import SystemTerminate
 from celery.worker import state
 from celery.worker import state
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class StateResetCase(Case):
+class StateResetCase(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         self.reset_state()
         self.reset_state()
-        self.on_setup()
 
 
-    def tearDown(self):
+    def teardown(self):
         self.reset_state()
         self.reset_state()
-        self.on_teardown()
 
 
     def reset_state(self):
     def reset_state(self):
         state.active_requests.clear()
         state.active_requests.clear()
         state.revoked.clear()
         state.revoked.clear()
         state.total_count.clear()
         state.total_count.clear()
 
 
-    def on_setup(self):
-        pass
-
-    def on_teardown(self):
-        pass
-
 
 
 class MockShelve(dict):
 class MockShelve(dict):
     filename = None
     filename = None
@@ -54,9 +46,9 @@ class MyPersistent(state.Persistent):
     storage = MockShelve()
     storage = MockShelve()
 
 
 
 
-class test_maybe_shutdown(Case):
+class test_maybe_shutdown(AppCase):
 
 
-    def tearDown(self):
+    def teardown(self):
         state.should_stop = False
         state.should_stop = False
         state.should_terminate = False
         state.should_terminate = False
 
 
@@ -73,7 +65,8 @@ class test_maybe_shutdown(Case):
 
 
 class test_Persistent(StateResetCase):
 class test_Persistent(StateResetCase):
 
 
-    def on_setup(self):
+    def setup(self):
+        self.reset_state()
         self.p = MyPersistent(state, filename='celery-state')
         self.p = MyPersistent(state, filename='celery-state')
 
 
     def test_close_twice(self):
     def test_close_twice(self):

+ 9 - 12
celery/tests/worker/test_strategy.py

@@ -6,7 +6,6 @@ from mock import Mock, patch
 
 
 from kombu.utils.limits import TokenBucket
 from kombu.utils.limits import TokenBucket
 
 
-from celery import Celery
 from celery.worker import state
 from celery.worker import state
 from celery.utils.timeutils import rate
 from celery.utils.timeutils import rate
 
 
@@ -15,6 +14,13 @@ from celery.tests.case import AppCase, body_from_sig
 
 
 class test_default_strategy(AppCase):
 class test_default_strategy(AppCase):
 
 
+    def setup(self):
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+
+        self.add = add
+
     class Context(object):
     class Context(object):
 
 
         def __init__(self, sig, s, reserved, consumer, message, body):
         def __init__(self, sig, s, reserved, consumer, message, body):
@@ -52,15 +58,6 @@ class test_default_strategy(AppCase):
                 return self.consumer.timer.apply_at.call_args[0][0]
                 return self.consumer.timer.apply_at.call_args[0][0]
             raise ValueError('request not handled')
             raise ValueError('request not handled')
 
 
-    def setup(self):
-        self.c = Celery(set_as_current=False)
-
-        @self.c.task()
-        def add(x, y):
-            return x + y
-
-        self.add = add
-
     @contextmanager
     @contextmanager
     def _context(self, sig,
     def _context(self, sig,
                  rate_limits=True, events=True, utc=True, limit=None):
                  rate_limits=True, events=True, utc=True, limit=None):
@@ -74,11 +71,11 @@ class test_default_strategy(AppCase):
             consumer.task_buckets[sig.task] = bucket
             consumer.task_buckets[sig.task] = bucket
         consumer.disable_rate_limits = not rate_limits
         consumer.disable_rate_limits = not rate_limits
         consumer.event_dispatcher.enabled = events
         consumer.event_dispatcher.enabled = events
-        s = sig.type.start_strategy(self.c, consumer, task_reserved=reserved)
+        s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)
         self.assertTrue(s)
         self.assertTrue(s)
 
 
         message = Mock()
         message = Mock()
-        body = body_from_sig(self.c, sig, utc=utc)
+        body = body_from_sig(self.app, sig, utc=utc)
 
 
         yield self.Context(sig, s, reserved, consumer, message, body)
         yield self.Context(sig, s, reserved, consumer, message, body)
 
 

+ 60 - 171
celery/tests/worker/test_worker.py

@@ -9,7 +9,7 @@ from threading import Event
 
 
 from billiard.exceptions import WorkerLostError
 from billiard.exceptions import WorkerLostError
 from kombu import Connection
 from kombu import Connection
-from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
+from kombu.common import QoS, ignore_errors
 from kombu.exceptions import StdChannelError
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from kombu.transport.base import Message
 from mock import call, Mock, patch
 from mock import call, Mock, patch
@@ -20,8 +20,6 @@ from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
 from celery.exceptions import SystemTerminate, TaskRevokedError
 from celery.exceptions import SystemTerminate, TaskRevokedError
 from celery.five import Empty, range, Queue as FastQueue
 from celery.five import Empty, range, Queue as FastQueue
-from celery.task import task as task_dec
-from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker import components
 from celery.worker import components
 from celery.worker import consumer
 from celery.worker import consumer
@@ -32,7 +30,7 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
-from celery.tests.case import AppCase, Case, restore_logging
+from celery.tests.case import AppCase, restore_logging
 
 
 
 
 def MockStep(step=None):
 def MockStep(step=None):
@@ -108,16 +106,6 @@ class MockHeart(object):
         self.closed = True
         self.closed = True
 
 
 
 
-@task_dec()
-def foo_task(x, y, z, **kwargs):
-    return x * y * z
-
-
-@periodic_task_dec(run_every=60)
-def foo_periodic_task():
-    return 'foo'
-
-
 def create_message(channel, **data):
 def create_message(channel, **data):
     data.setdefault('id', uuid())
     data.setdefault('id', uuid())
     channel.no_ack_consumers = set()
     channel.no_ack_consumers = set()
@@ -129,117 +117,17 @@ def create_message(channel, **data):
     return m
     return m
 
 
 
 
-class test_QoS(Case):
-
-    class _QoS(QoS):
-        def __init__(self, value):
-            self.value = value
-            QoS.__init__(self, None, value)
-
-        def set(self, value):
-            return value
-
-    def test_qos_increment_decrement(self):
-        qos = self._QoS(10)
-        self.assertEqual(qos.increment_eventually(), 11)
-        self.assertEqual(qos.increment_eventually(3), 14)
-        self.assertEqual(qos.increment_eventually(-30), 14)
-        self.assertEqual(qos.decrement_eventually(7), 7)
-        self.assertEqual(qos.decrement_eventually(), 6)
-
-    def test_qos_disabled_increment_decrement(self):
-        qos = self._QoS(0)
-        self.assertEqual(qos.increment_eventually(), 0)
-        self.assertEqual(qos.increment_eventually(3), 0)
-        self.assertEqual(qos.increment_eventually(-30), 0)
-        self.assertEqual(qos.decrement_eventually(7), 0)
-        self.assertEqual(qos.decrement_eventually(), 0)
-        self.assertEqual(qos.decrement_eventually(10), 0)
-
-    def test_qos_thread_safe(self):
-        qos = self._QoS(10)
-
-        def add():
-            for i in range(1000):
-                qos.increment_eventually()
-
-        def sub():
-            for i in range(1000):
-                qos.decrement_eventually()
-
-        def threaded(funs):
-            from threading import Thread
-            threads = [Thread(target=fun) for fun in funs]
-            for thread in threads:
-                thread.start()
-            for thread in threads:
-                thread.join()
-
-        threaded([add, add])
-        self.assertEqual(qos.value, 2010)
-
-        qos.value = 1000
-        threaded([add, sub])  # n = 2
-        self.assertEqual(qos.value, 1000)
-
-    def test_exceeds_short(self):
-        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
-        qos.update()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-        qos.increment_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.increment_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-
-    def test_consumer_increment_decrement(self):
-        mconsumer = Mock()
-        qos = QoS(mconsumer.qos, 10)
-        qos.update()
-        self.assertEqual(qos.value, 10)
-        mconsumer.qos.assert_called_with(prefetch_count=10)
-        qos.decrement_eventually()
-        qos.update()
-        self.assertEqual(qos.value, 9)
-        mconsumer.qos.assert_called_with(prefetch_count=9)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 8)
-        mconsumer.qos.assert_called_with(prefetch_count=9)
-        self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args)
-
-        # Does not decrement 0 value
-        qos.value = 0
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 0)
-        qos.increment_eventually()
-        self.assertEqual(qos.value, 0)
-
-    def test_consumer_decrement_eventually(self):
-        mconsumer = Mock()
-        qos = QoS(mconsumer.qos, 10)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 9)
-        qos.value = 0
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 0)
-
-    def test_set(self):
-        mconsumer = Mock()
-        qos = QoS(mconsumer.qos, 10)
-        qos.set(12)
-        self.assertEqual(qos.prev, 12)
-        qos.set(qos.prev)
-
-
 class test_Consumer(AppCase):
 class test_Consumer(AppCase):
 
 
     def setup(self):
     def setup(self):
         self.buffer = FastQueue()
         self.buffer = FastQueue()
         self.timer = Timer()
         self.timer = Timer()
 
 
+        @self.app.task(shared=False)
+        def foo_task(x, y, z):
+            return x * y * z
+        self.foo_task = foo_task
+
     def teardown(self):
     def teardown(self):
         self.timer.stop()
         self.timer.stop()
 
 
@@ -326,7 +214,7 @@ class test_Consumer(AppCase):
         to_timestamp.side_effect = OverflowError()
         to_timestamp.side_effect = OverflowError()
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
-        m = create_message(Mock(), task=foo_task.name,
+        m = create_message(Mock(), task=self.foo_task.name,
                            args=('2, 2'),
                            args=('2, 2'),
                            kwargs={},
                            kwargs={},
                            eta=datetime.now().isoformat())
                            eta=datetime.now().isoformat())
@@ -344,7 +232,7 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.steps.pop()
         l.steps.pop()
-        m = create_message(Mock(), task=foo_task.name,
+        m = create_message(Mock(), task=self.foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.update_strategies()
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
@@ -383,7 +271,7 @@ class test_Consumer(AppCase):
     def test_receieve_message(self):
     def test_receieve_message(self):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
-        m = create_message(Mock(), task=foo_task.name,
+        m = create_message(Mock(), task=self.foo_task.name,
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
         l.update_strategies()
         callback = self._get_on_message(l)
         callback = self._get_on_message(l)
@@ -391,7 +279,7 @@ class test_Consumer(AppCase):
 
 
         in_bucket = self.buffer.get_nowait()
         in_bucket = self.buffer.get_nowait()
         self.assertIsInstance(in_bucket, Request)
         self.assertIsInstance(in_bucket, Request)
-        self.assertEqual(in_bucket.name, foo_task.name)
+        self.assertEqual(in_bucket.name, self.foo_task.name)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.timer.empty())
         self.assertTrue(self.timer.empty())
 
 
@@ -520,7 +408,7 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         m = create_message(
         m = create_message(
-            Mock(), task=foo_task.name,
+            Mock(), task=self.foo_task.name,
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
             args=[2, 4, 8], kwargs={},
             args=[2, 4, 8], kwargs={},
         )
         )
@@ -539,7 +427,7 @@ class test_Consumer(AppCase):
         items = [entry[2] for entry in self.timer.queue]
         items = [entry[2] for entry in self.timer.queue]
         found = 0
         found = 0
         for item in items:
         for item in items:
-            if item.args[0].name == foo_task.name:
+            if item.args[0].name == self.foo_task.name:
                 found = True
                 found = True
         self.assertTrue(found)
         self.assertTrue(found)
         self.assertGreater(l.qos.value, current_pcount)
         self.assertGreater(l.qos.value, current_pcount)
@@ -570,7 +458,7 @@ class test_Consumer(AppCase):
         l.steps.pop()
         l.steps.pop()
         backend = Mock()
         backend = Mock()
         id = uuid()
         id = uuid()
-        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
+        t = create_message(backend, task=self.foo_task.name, args=[2, 4, 8],
                            kwargs={}, id=id)
                            kwargs={}, id=id)
         from celery.worker.state import revoked
         from celery.worker.state import revoked
         revoked.add(id)
         revoked.add(id)
@@ -619,7 +507,7 @@ class test_Consumer(AppCase):
         l.event_dispatcher._outbound_buffer = deque()
         l.event_dispatcher._outbound_buffer = deque()
         backend = Mock()
         backend = Mock()
         m = create_message(
         m = create_message(
-            backend, task=foo_task.name,
+            backend, task=self.foo_task.name,
             args=[2, 4, 8], kwargs={},
             args=[2, 4, 8], kwargs={},
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
         )
         )
@@ -627,10 +515,8 @@ class test_Consumer(AppCase):
         l.blueprint.start(l)
         l.blueprint.start(l)
         p = l.app.conf.BROKER_CONNECTION_RETRY
         p = l.app.conf.BROKER_CONNECTION_RETRY
         l.app.conf.BROKER_CONNECTION_RETRY = False
         l.app.conf.BROKER_CONNECTION_RETRY = False
-        try:
-            l.blueprint.start(l)
-        finally:
-            l.app.conf.BROKER_CONNECTION_RETRY = p
+        l.blueprint.start(l)
+        l.app.conf.BROKER_CONNECTION_RETRY = p
         l.blueprint.restart(l)
         l.blueprint.restart(l)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         callback = self._get_on_message(l)
         callback = self._get_on_message(l)
@@ -641,7 +527,7 @@ class test_Consumer(AppCase):
         eta, priority, entry = in_hold
         eta, priority, entry = in_hold
         task = entry.args[0]
         task = entry.args[0]
         self.assertIsInstance(task, Request)
         self.assertIsInstance(task, Request)
-        self.assertEqual(task.name, foo_task.name)
+        self.assertEqual(task.name, self.foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         with self.assertRaises(Empty):
         with self.assertRaises(Empty):
             self.buffer.get_nowait()
             self.buffer.get_nowait()
@@ -823,6 +709,11 @@ class test_WorkController(AppCase):
         self.logger = worker.logger = Mock()
         self.logger = worker.logger = Mock()
         self.comp_logger = components.logger = Mock()
         self.comp_logger = components.logger = Mock()
 
 
+        @self.app.task(shared=False)
+        def foo_task(x, y, z):
+            return x * y * z
+        self.foo_task = foo_task
+
     def teardown(self):
     def teardown(self):
         from celery import worker
         from celery import worker
         worker.logger = self._logger
         worker.logger = self._logger
@@ -838,15 +729,11 @@ class test_WorkController(AppCase):
 
 
     def test_setup_queues_worker_direct(self):
     def test_setup_queues_worker_direct(self):
         self.app.conf.CELERY_WORKER_DIRECT = True
         self.app.conf.CELERY_WORKER_DIRECT = True
-        _qs, self.app.amqp.__dict__['queues'] = self.app.amqp.queues, Mock()
-        try:
-            self.worker.setup_queues({})
-            self.app.amqp.queues.select_add.assert_called_with(
-                worker_direct(self.worker.hostname),
-            )
-        finally:
-            self.app.amqp.queues = _qs
-            self.app.conf.CELERY_WORKER_DIRECT = False
+        self.app.amqp.__dict__['queues'] = Mock()
+        self.worker.setup_queues({})
+        self.app.amqp.queues.select_add.assert_called_with(
+            worker_direct(self.worker.hostname),
+        )
 
 
     def test_send_worker_shutdown(self):
     def test_send_worker_shutdown(self):
         with patch('celery.signals.worker_shutdown') as ws:
         with patch('celery.signals.worker_shutdown') as ws:
@@ -881,40 +768,42 @@ class test_WorkController(AppCase):
     @patch('celery.platforms.set_mp_process_title')
     @patch('celery.platforms.set_mp_process_title')
     def test_process_initializer(self, set_mp_process_title, _signals):
     def test_process_initializer(self, set_mp_process_title, _signals):
         with restore_logging():
         with restore_logging():
-            from celery import Celery
             from celery import signals
             from celery import signals
             from celery._state import _tls
             from celery._state import _tls
-            from celery.concurrency.processes import process_initializer
-            from celery.concurrency.processes import (WORKER_SIGRESET,
-                                                      WORKER_SIGIGNORE)
+            from celery.concurrency.processes import (
+                process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
+            )
 
 
             def on_worker_process_init(**kwargs):
             def on_worker_process_init(**kwargs):
                 on_worker_process_init.called = True
                 on_worker_process_init.called = True
             on_worker_process_init.called = False
             on_worker_process_init.called = False
             signals.worker_process_init.connect(on_worker_process_init)
             signals.worker_process_init.connect(on_worker_process_init)
 
 
-            loader = Mock()
-            loader.override_backends = {}
-            app = Celery(loader=loader, set_as_current=False)
-            app.loader = loader
-            app.conf = AttributeDict(DEFAULTS)
-            process_initializer(app, 'awesome.worker.com')
-            _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
-            _signals.reset.assert_any_call(*WORKER_SIGRESET)
-            self.assertTrue(app.loader.init_worker.call_count)
-            self.assertTrue(on_worker_process_init.called)
-            self.assertIs(_tls.current_app, app)
-            set_mp_process_title.assert_called_with(
-                'celeryd', hostname='awesome.worker.com',
-            )
-
-            with patch('celery.app.trace.setup_worker_optimizations') as swo:
-                os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
-                try:
-                    process_initializer(app, 'luke.worker.com')
-                    swo.assert_called_with(app)
-                finally:
-                    os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
+            def Loader(*args, **kwargs):
+                loader = Mock(*args, **kwargs)
+                loader.conf = {}
+                loader.override_backends = {}
+                return loader
+
+            with self.Celery(loader=Loader) as app:
+                app.conf = AttributeDict(DEFAULTS)
+                process_initializer(app, 'awesome.worker.com')
+                _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
+                _signals.reset.assert_any_call(*WORKER_SIGRESET)
+                self.assertTrue(app.loader.init_worker.call_count)
+                self.assertTrue(on_worker_process_init.called)
+                self.assertIs(_tls.current_app, app)
+                set_mp_process_title.assert_called_with(
+                    'celeryd', hostname='awesome.worker.com',
+                )
+
+                with patch('celery.app.trace.setup_worker_optimizations') as S:
+                    os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
+                    try:
+                        process_initializer(app, 'luke.worker.com')
+                        S.assert_called_with(app)
+                    finally:
+                        os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
 
 
     def test_attrs(self):
     def test_attrs(self):
         worker = self.worker
         worker = self.worker
@@ -976,7 +865,7 @@ class test_WorkController(AppCase):
         worker = self.worker
         worker = self.worker
         worker.pool = Mock()
         worker.pool = Mock()
         backend = Mock()
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)
         worker._process_task(task)
@@ -988,7 +877,7 @@ class test_WorkController(AppCase):
         worker.pool = Mock()
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
         worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
         backend = Mock()
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
         worker.steps = []
@@ -1002,7 +891,7 @@ class test_WorkController(AppCase):
         worker.pool = Mock()
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = SystemTerminate()
         worker.pool.apply_async.side_effect = SystemTerminate()
         backend = Mock()
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
         worker.steps = []
@@ -1016,7 +905,7 @@ class test_WorkController(AppCase):
         worker.pool = Mock()
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = KeyError('some exception')
         worker.pool.apply_async.side_effect = KeyError('some exception')
         backend = Mock()
         backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode(), app=self.app)
         task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)
         worker._process_task(task)