Browse Source

Tests passing

Ask Solem 11 years ago
parent
commit
dbe733d234
37 changed files with 492 additions and 474 deletions
  1. 1 0
      celery/__init__.py
  2. 1 0
      celery/_state.py
  3. 1 3
      celery/app/control.py
  4. 1 2
      celery/apps/beat.py
  5. 6 8
      celery/beat.py
  6. 1 2
      celery/bin/amqp.py
  7. 3 3
      celery/events/cursesmon.py
  8. 2 3
      celery/loaders/base.py
  9. 1 1
      celery/schedules.py
  10. 1 1
      celery/tests/app/test_app.py
  11. 5 3
      celery/tests/app/test_beat.py
  12. 10 9
      celery/tests/app/test_loaders.py
  13. 8 10
      celery/tests/backends/test_amqp.py
  14. 7 6
      celery/tests/backends/test_backends.py
  15. 61 45
      celery/tests/backends/test_base.py
  16. 10 11
      celery/tests/backends/test_cache.py
  17. 18 19
      celery/tests/backends/test_database.py
  18. 7 7
      celery/tests/backends/test_mongodb.py
  19. 24 25
      celery/tests/backends/test_redis.py
  20. 13 13
      celery/tests/bin/test_beat.py
  21. 2 2
      celery/tests/bin/test_celeryd_detach.py
  22. 3 5
      celery/tests/bin/test_events.py
  23. 34 32
      celery/tests/bin/test_worker.py
  24. 2 18
      celery/tests/case.py
  25. 13 12
      celery/tests/compat_modules/test_sets.py
  26. 4 4
      celery/tests/events/test_cursesmon.py
  27. 6 9
      celery/tests/events/test_snapshot.py
  28. 3 3
      celery/tests/security/case.py
  29. 40 28
      celery/tests/security/test_security.py
  30. 27 28
      celery/tests/tasks/test_chord.py
  31. 5 5
      celery/tests/tasks/test_http.py
  32. 15 14
      celery/tests/tasks/test_result.py
  33. 33 29
      celery/tests/tasks/test_tasks.py
  34. 4 3
      celery/tests/worker/test_control.py
  35. 25 25
      celery/tests/worker/test_loops.py
  36. 54 46
      celery/tests/worker/test_request.py
  37. 41 40
      celery/tests/worker/test_worker.py

+ 1 - 0
celery/__init__.py

@@ -44,6 +44,7 @@ if STATICA_HACK:  # pragma: no cover
     # This is never executed, but tricks static analyzers (PyDev, PyCharm,
     # This is never executed, but tricks static analyzers (PyDev, PyCharm,
     # pylint, etc.) into knowing the types of these symbols, and what
     # pylint, etc.) into knowing the types of these symbols, and what
     # they contain.
     # they contain.
+    from celery.app import shared_task                   # noqa
     from celery.app.base import Celery                   # noqa
     from celery.app.base import Celery                   # noqa
     from celery.app.utils import bugreport               # noqa
     from celery.app.utils import bugreport               # noqa
     from celery.app.task import Task                     # noqa
     from celery.app.task import Task                     # noqa

+ 1 - 0
celery/_state.py

@@ -55,6 +55,7 @@ def _get_current_app():
 C_STRICT_APP = os.environ.get('C_STRICT_APP')
 C_STRICT_APP = os.environ.get('C_STRICT_APP')
 if os.environ.get('C_STRICT_APP'):  # pragma: no cover
 if os.environ.get('C_STRICT_APP'):  # pragma: no cover
     def get_current_app():
     def get_current_app():
+        raise Exception('USES CURRENT APP')
         import traceback
         import traceback
         print('-- USES CURRENT_APP', file=sys.stderr)  # noqa+
         print('-- USES CURRENT_APP', file=sys.stderr)  # noqa+
         traceback.print_stack(file=sys.stderr)
         traceback.print_stack(file=sys.stderr)

+ 1 - 3
celery/app/control.py

@@ -16,8 +16,6 @@ from kombu.utils import cached_property
 
 
 from celery.exceptions import DuplicateNodenameWarning
 from celery.exceptions import DuplicateNodenameWarning
 
 
-from . import app_or_default
-
 W_DUPNODE = """\
 W_DUPNODE = """\
 Received multiple replies from node name {0!r}.
 Received multiple replies from node name {0!r}.
 Please make sure you give each node a unique nodename using the `-n` option.\
 Please make sure you give each node a unique nodename using the `-n` option.\
@@ -121,7 +119,7 @@ class Control(object):
     Mailbox = Mailbox
     Mailbox = Mailbox
 
 
     def __init__(self, app=None):
     def __init__(self, app=None):
-        self.app = app_or_default(app)
+        self.app = app
         self.mailbox = self.Mailbox('celery', type='fanout',
         self.mailbox = self.Mailbox('celery', type='fanout',
                                     accept=self.app.conf.CELERY_ACCEPT_CONTENT)
                                     accept=self.app.conf.CELERY_ACCEPT_CONTENT)
 
 

+ 1 - 2
celery/apps/beat.py

@@ -16,7 +16,6 @@ import socket
 import sys
 import sys
 
 
 from celery import VERSION_BANNER, platforms, beat
 from celery import VERSION_BANNER, platforms, beat
-from celery.app import app_or_default
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
 from celery.utils.log import LOG_LEVELS, get_logger
 from celery.utils.log import LOG_LEVELS, get_logger
 from celery.utils.timeutils import humanize_seconds
 from celery.utils.timeutils import humanize_seconds
@@ -44,7 +43,7 @@ class Beat(object):
                  scheduler_cls=None, redirect_stdouts=None,
                  scheduler_cls=None, redirect_stdouts=None,
                  redirect_stdouts_level=None, **kwargs):
                  redirect_stdouts_level=None, **kwargs):
         """Starts the beat task scheduler."""
         """Starts the beat task scheduler."""
-        self.app = app = app_or_default(app or self.app)
+        self.app = app = app or self.app
         self.loglevel = self._getopt('log_level', loglevel)
         self.loglevel = self._getopt('log_level', loglevel)
         self.logfile = self._getopt('log_file', logfile)
         self.logfile = self._getopt('log_file', logfile)
         self.schedule = self._getopt('schedule_filename', schedule)
         self.schedule = self._getopt('schedule_filename', schedule)

+ 6 - 8
celery/beat.py

@@ -25,7 +25,6 @@ from . import __version__
 from . import platforms
 from . import platforms
 from . import signals
 from . import signals
 from . import current_app
 from . import current_app
-from .app import app_or_default
 from .five import items, reraise, values
 from .five import items, reraise, values
 from .schedules import maybe_schedule, crontab
 from .schedules import maybe_schedule, crontab
 from .utils.imports import instantiate
 from .utils.imports import instantiate
@@ -135,7 +134,6 @@ class Scheduler(object):
     :keyword max_interval: see :attr:`max_interval`.
     :keyword max_interval: see :attr:`max_interval`.
 
 
     """
     """
-
     Entry = ScheduleEntry
     Entry = ScheduleEntry
 
 
     #: The schedule dict/shelve.
     #: The schedule dict/shelve.
@@ -151,9 +149,9 @@ class Scheduler(object):
 
 
     logger = logger  # compat
     logger = logger  # compat
 
 
-    def __init__(self, schedule=None, max_interval=None,
-                 app=None, Publisher=None, lazy=False, **kwargs):
-        app = self.app = app_or_default(app)
+    def __init__(self, app, schedule=None, max_interval=None,
+                 Publisher=None, lazy=False, **kwargs):
+        self.app = app
         self.data = maybe_promise({} if schedule is None else schedule)
         self.data = maybe_promise({} if schedule is None else schedule)
         self.max_interval = (max_interval
         self.max_interval = (max_interval
                              or app.conf.CELERYBEAT_MAX_LOOP_INTERVAL
                              or app.conf.CELERYBEAT_MAX_LOOP_INTERVAL
@@ -398,9 +396,9 @@ class PersistentScheduler(Scheduler):
 class Service(object):
 class Service(object):
     scheduler_cls = PersistentScheduler
     scheduler_cls = PersistentScheduler
 
 
-    def __init__(self, max_interval=None, schedule_filename=None,
-                 scheduler_cls=None, app=None):
-        app = self.app = app_or_default(app)
+    def __init__(self, app, max_interval=None, schedule_filename=None,
+                 scheduler_cls=None):
+        self.app = app
         self.max_interval = (max_interval
         self.max_interval = (max_interval
                              or app.conf.CELERYBEAT_MAX_LOOP_INTERVAL)
                              or app.conf.CELERYBEAT_MAX_LOOP_INTERVAL)
         self.scheduler_cls = scheduler_cls or self.scheduler_cls
         self.scheduler_cls = scheduler_cls or self.scheduler_cls

+ 1 - 2
celery/bin/amqp.py

@@ -18,7 +18,6 @@ from itertools import count
 
 
 from amqp import Message
 from amqp import Message
 
 
-from celery.app import app_or_default
 from celery.utils.functional import padlist
 from celery.utils.functional import padlist
 
 
 from celery.bin.base import Command
 from celery.bin.base import Command
@@ -328,7 +327,7 @@ class AMQPAdmin(object):
     Shell = AMQShell
     Shell = AMQShell
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
-        self.app = app_or_default(kwargs.get('app'))
+        self.app = kwargs['app']
         self.out = kwargs.setdefault('out', sys.stderr)
         self.out = kwargs.setdefault('out', sys.stderr)
         self.silent = kwargs.get('silent')
         self.silent = kwargs.get('silent')
         self.args = args
         self.args = args

+ 3 - 3
celery/events/cursesmon.py

@@ -56,8 +56,8 @@ class CursesMonitor(object):  # pragma: no cover
     greet = 'celery events {0}'.format(VERSION_BANNER)
     greet = 'celery events {0}'.format(VERSION_BANNER)
     info_str = 'Info: '
     info_str = 'Info: '
 
 
-    def __init__(self, state, keymap=None, app=None):
-        self.app = app_or_default(app)
+    def __init__(self, state, app, keymap=None):
+        self.app = app
         self.keymap = keymap or self.keymap
         self.keymap = keymap or self.keymap
         self.state = state
         self.state = state
         default_keymap = {'J': self.move_selection_down,
         default_keymap = {'J': self.move_selection_down,
@@ -521,7 +521,7 @@ def capture_events(app, state, display):  # pragma: no cover
 def evtop(app=None):  # pragma: no cover
 def evtop(app=None):  # pragma: no cover
     app = app_or_default(app)
     app = app_or_default(app)
     state = app.events.State()
     state = app.events.State()
-    display = CursesMonitor(state, app=app)
+    display = CursesMonitor(state, app)
     display.init_screen()
     display.init_screen()
     refresher = DisplayThread(display)
     refresher = DisplayThread(display)
     refresher.start()
     refresher.start()

+ 2 - 3
celery/loaders/base.py

@@ -64,9 +64,8 @@ class BaseLoader(object):
 
 
     _conf = None
     _conf = None
 
 
-    def __init__(self, app=None, **kwargs):
-        from celery.app import app_or_default
-        self.app = app_or_default(app)
+    def __init__(self, app, **kwargs):
+        self.app = app
         self.task_modules = set()
         self.task_modules = set()
 
 
     def now(self, utc=True):
     def now(self, utc=True):

+ 1 - 1
celery/schedules.py

@@ -529,7 +529,7 @@ class crontab(schedule):
                     other.day_of_week == self.day_of_week and
                     other.day_of_week == self.day_of_week and
                     other.hour == self.hour and
                     other.hour == self.hour and
                     other.minute == self.minute)
                     other.minute == self.minute)
-        return other is self
+        return NotImplemented
 
 
     def __ne__(self, other):
     def __ne__(self, other):
         return not self.__eq__(other)
         return not self.__eq__(other)

+ 1 - 1
celery/tests/app/test_app.py

@@ -513,7 +513,7 @@ class test_App(Case):
             def mail_admins(*args, **kwargs):
             def mail_admins(*args, **kwargs):
                 return args, kwargs
                 return args, kwargs
 
 
-        self.app.loader = Loader()
+        self.app.loader = Loader(app=self.app)
         self.app.conf.ADMINS = None
         self.app.conf.ADMINS = None
         self.assertFalse(self.app.mail_admins('Subject', 'Body'))
         self.assertFalse(self.app.mail_admins('Subject', 'Body'))
         self.app.conf.ADMINS = [('George Costanza', 'george@vandelay.com')]
         self.app.conf.ADMINS = [('George Costanza', 'george@vandelay.com')]

+ 5 - 3
celery/tests/app/test_beat.py

@@ -379,7 +379,9 @@ class test_PersistentScheduler(AppCase):
         s._store.clear.assert_called_with()
         s._store.clear.assert_called_with()
 
 
     def test_get_schedule(self):
     def test_get_schedule(self):
-        s = create_persistent_scheduler()[0](schedule_filename='schedule')
+        s = create_persistent_scheduler()[0](
+            schedule_filename='schedule', app=self.app,
+        )
         s._store = {'entries': {}}
         s._store = {'entries': {}}
         s.schedule = {'foo': 'bar'}
         s.schedule = {'foo': 'bar'}
         self.assertDictEqual(s.schedule, {'foo': 'bar'})
         self.assertDictEqual(s.schedule, {'foo': 'bar'})
@@ -455,7 +457,7 @@ class test_EmbeddedService(AppCase):
 
 
         from billiard.process import Process
         from billiard.process import Process
 
 
-        s = beat.EmbeddedService()
+        s = beat.EmbeddedService(app=self.app)
         self.assertIsInstance(s, Process)
         self.assertIsInstance(s, Process)
         self.assertIsInstance(s.service, beat.Service)
         self.assertIsInstance(s.service, beat.Service)
         s.service = MockService()
         s.service = MockService()
@@ -475,7 +477,7 @@ class test_EmbeddedService(AppCase):
         self.assertTrue(s._popen.terminated)
         self.assertTrue(s._popen.terminated)
 
 
     def test_start_stop_threaded(self):
     def test_start_stop_threaded(self):
-        s = beat.EmbeddedService(thread=True)
+        s = beat.EmbeddedService(thread=True, app=self.app)
         from threading import Thread
         from threading import Thread
         self.assertIsInstance(s, Thread)
         self.assertIsInstance(s, Thread)
         self.assertIsInstance(s.service, beat.Service)
         self.assertIsInstance(s.service, beat.Service)

+ 10 - 9
celery/tests/app/test_loaders.py

@@ -70,7 +70,8 @@ class test_LoaderBase(AppCase):
 
 
     def test_read_configuration_no_env(self):
     def test_read_configuration_no_env(self):
         self.assertDictEqual(
         self.assertDictEqual(
-            base.BaseLoader().read_configuration('FOO_X_S_WE_WQ_Q_WE'),
+            base.BaseLoader(app=self.app).read_configuration(
+                'FOO_X_S_WE_WQ_Q_WE'),
             {},
             {},
         )
         )
 
 
@@ -146,7 +147,7 @@ class test_LoaderBase(AppCase):
 
 
     def test_mail_attribute(self):
     def test_mail_attribute(self):
         from celery.utils import mail
         from celery.utils import mail
-        loader = base.BaseLoader()
+        loader = base.BaseLoader(app=self.app)
         self.assertIs(loader.mail, mail)
         self.assertIs(loader.mail, mail)
 
 
     def test_cmdline_config_ValueError(self):
     def test_cmdline_config_ValueError(self):
@@ -154,12 +155,12 @@ class test_LoaderBase(AppCase):
             self.loader.cmdline_config_parser(['broker.port=foobar'])
             self.loader.cmdline_config_parser(['broker.port=foobar'])
 
 
 
 
-class test_DefaultLoader(Case):
+class test_DefaultLoader(AppCase):
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
     def test_read_configuration_not_a_package(self, find_module):
     def test_read_configuration_not_a_package(self, find_module):
         find_module.side_effect = NotAPackage()
         find_module.side_effect = NotAPackage()
-        l = default.Loader()
+        l = default.Loader(app=self.app)
         with self.assertRaises(NotAPackage):
         with self.assertRaises(NotAPackage):
             l.read_configuration()
             l.read_configuration()
 
 
@@ -169,7 +170,7 @@ class test_DefaultLoader(Case):
         os.environ['CELERY_CONFIG_MODULE'] = 'celeryconfig.py'
         os.environ['CELERY_CONFIG_MODULE'] = 'celeryconfig.py'
         try:
         try:
             find_module.side_effect = NotAPackage()
             find_module.side_effect = NotAPackage()
-            l = default.Loader()
+            l = default.Loader(app=self.app)
             with self.assertRaises(NotAPackage):
             with self.assertRaises(NotAPackage):
                 l.read_configuration()
                 l.read_configuration()
         finally:
         finally:
@@ -179,7 +180,7 @@ class test_DefaultLoader(Case):
     def test_read_configuration_importerror(self, find_module):
     def test_read_configuration_importerror(self, find_module):
         default.C_WNOCONF = True
         default.C_WNOCONF = True
         find_module.side_effect = ImportError()
         find_module.side_effect = ImportError()
-        l = default.Loader()
+        l = default.Loader(app=self.app)
         with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
         with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
             l.read_configuration()
             l.read_configuration()
         default.C_WNOCONF = False
         default.C_WNOCONF = False
@@ -198,7 +199,7 @@ class test_DefaultLoader(Case):
         prevconfig = sys.modules.get(configname)
         prevconfig = sys.modules.get(configname)
         sys.modules[configname] = celeryconfig
         sys.modules[configname] = celeryconfig
         try:
         try:
-            l = default.Loader()
+            l = default.Loader(app=self.app)
             settings = l.read_configuration()
             settings = l.read_configuration()
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
             settings = l.read_configuration()
             settings = l.read_configuration()
@@ -209,7 +210,7 @@ class test_DefaultLoader(Case):
                 sys.modules[configname] = prevconfig
                 sys.modules[configname] = prevconfig
 
 
     def test_import_from_cwd(self):
     def test_import_from_cwd(self):
-        l = default.Loader()
+        l = default.Loader(app=self.app)
         old_path = list(sys.path)
         old_path = list(sys.path)
         try:
         try:
             sys.path.remove(os.getcwd())
             sys.path.remove(os.getcwd())
@@ -234,7 +235,7 @@ class test_DefaultLoader(Case):
                 raise ImportError(name)
                 raise ImportError(name)
 
 
         with warnings.catch_warnings(record=True):
         with warnings.catch_warnings(record=True):
-            l = _Loader()
+            l = _Loader(app=self.app)
             self.assertFalse(l.configured)
             self.assertFalse(l.configured)
             context_executed[0] = True
             context_executed[0] = True
         self.assertTrue(context_executed[0])
         self.assertTrue(context_executed[0])

+ 8 - 10
celery/tests/backends/test_amqp.py

@@ -10,9 +10,7 @@ from pickle import dumps, loads
 from billiard.einfo import ExceptionInfo
 from billiard.einfo import ExceptionInfo
 from mock import patch
 from mock import patch
 
 
-from celery import current_app
 from celery import states
 from celery import states
-from celery.app import app_or_default
 from celery.backends.amqp import AMQPBackend
 from celery.backends.amqp import AMQPBackend
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
 from celery.five import Empty, Queue, range
 from celery.five import Empty, Queue, range
@@ -31,7 +29,7 @@ class test_AMQPBackend(AppCase):
 
 
     def create_backend(self, **opts):
     def create_backend(self, **opts):
         opts = dict(dict(serializer='pickle', persistent=False), **opts)
         opts = dict(dict(serializer='pickle', persistent=False), **opts)
-        return AMQPBackend(**opts)
+        return AMQPBackend(self.app, **opts)
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
         tb1 = self.create_backend()
         tb1 = self.create_backend()
@@ -107,7 +105,7 @@ class test_AMQPBackend(AppCase):
             iterations[0] += 1
             iterations[0] += 1
             raise KeyError('foo')
             raise KeyError('foo')
 
 
-        backend = AMQPBackend()
+        backend = AMQPBackend(self.app)
         from celery.app.amqp import TaskProducer
         from celery.app.amqp import TaskProducer
         prod, TaskProducer.publish = TaskProducer.publish, publish
         prod, TaskProducer.publish = TaskProducer.publish, publish
         try:
         try:
@@ -172,7 +170,7 @@ class test_AMQPBackend(AppCase):
         class MockBackend(AMQPBackend):
         class MockBackend(AMQPBackend):
             Queue = MockBinding
             Queue = MockBinding
 
 
-        backend = MockBackend()
+        backend = MockBackend(self.app)
         backend._republish = Mock()
         backend._republish = Mock()
 
 
         yield results, backend, Message
         yield results, backend, Message
@@ -251,7 +249,7 @@ class test_AMQPBackend(AppCase):
                 pass
                 pass
 
 
         b = self.create_backend()
         b = self.create_backend()
-        with current_app.pool.acquire_channel(block=False) as (_, channel):
+        with self.app.pool.acquire_channel(block=False) as (_, channel):
             binding = b._create_binding(uuid())
             binding = b._create_binding(uuid())
             consumer = b.Consumer(channel, binding, no_ack=True)
             consumer = b.Consumer(channel, binding, no_ack=True)
             with self.assertRaises(socket.timeout):
             with self.assertRaises(socket.timeout):
@@ -296,14 +294,14 @@ class test_AMQPBackend(AppCase):
             def Consumer(*args, **kwargs):
             def Consumer(*args, **kwargs):
                 raise KeyError('foo')
                 raise KeyError('foo')
 
 
-        b = Backend()
+        b = Backend(self.app)
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             next(b.get_many(['id1']))
             next(b.get_many(['id1']))
 
 
     def test_get_many_raises_inner_block(self):
     def test_get_many_raises_inner_block(self):
         with patch('kombu.connection.Connection.drain_events') as drain:
         with patch('kombu.connection.Connection.drain_events') as drain:
             drain.side_effect = KeyError('foo')
             drain.side_effect = KeyError('foo')
-            b = AMQPBackend()
+            b = AMQPBackend(self.app)
             with self.assertRaises(KeyError):
             with self.assertRaises(KeyError):
                 next(b.get_many(['id1']))
                 next(b.get_many(['id1']))
 
 
@@ -314,13 +312,13 @@ class test_AMQPBackend(AppCase):
                 drain.side_effect = ValueError()
                 drain.side_effect = ValueError()
                 raise KeyError('foo')
                 raise KeyError('foo')
             drain.side_effect = se
             drain.side_effect = se
-            b = AMQPBackend()
+            b = AMQPBackend(self.app)
             with self.assertRaises(ValueError):
             with self.assertRaises(ValueError):
                 next(b.consume('id1'))
                 next(b.consume('id1'))
 
 
     def test_no_expires(self):
     def test_no_expires(self):
         b = self.create_backend(expires=None)
         b = self.create_backend(expires=None)
-        app = app_or_default()
+        app = self.app
         prev = app.conf.CELERY_TASK_RESULT_EXPIRES
         prev = app.conf.CELERY_TASK_RESULT_EXPIRES
         app.conf.CELERY_TASK_RESULT_EXPIRES = None
         app.conf.CELERY_TASK_RESULT_EXPIRES = None
         try:
         try:

+ 7 - 6
celery/tests/backends/test_backends.py

@@ -2,21 +2,22 @@ from __future__ import absolute_import
 
 
 from mock import patch
 from mock import patch
 
 
-from celery import current_app
 from celery import backends
 from celery import backends
 from celery.backends.amqp import AMQPBackend
 from celery.backends.amqp import AMQPBackend
 from celery.backends.cache import CacheBackend
 from celery.backends.cache import CacheBackend
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class test_backends(Case):
+class test_backends(AppCase):
 
 
     def test_get_backend_aliases(self):
     def test_get_backend_aliases(self):
         expects = [('amqp', AMQPBackend),
         expects = [('amqp', AMQPBackend),
                    ('cache', CacheBackend)]
                    ('cache', CacheBackend)]
         for expect_name, expect_cls in expects:
         for expect_name, expect_cls in expects:
-            self.assertIsInstance(backends.get_backend_cls(expect_name)(),
-                                  expect_cls)
+            self.assertIsInstance(
+                backends.get_backend_cls(expect_name)(app=self.app),
+                expect_cls,
+            )
 
 
     def test_get_backend_cache(self):
     def test_get_backend_cache(self):
         backends.get_backend_cls.clear()
         backends.get_backend_cls.clear()
@@ -32,7 +33,7 @@ class test_backends(Case):
             backends.get_backend_cls('fasodaopjeqijwqe')
             backends.get_backend_cls('fasodaopjeqijwqe')
 
 
     def test_default_backend(self):
     def test_default_backend(self):
-        self.assertEqual(backends.default_backend, current_app.backend)
+        self.assertEqual(backends.default_backend, self.app.backend)
 
 
     def test_backend_by_url(self, url='redis://localhost/1'):
     def test_backend_by_url(self, url='redis://localhost/1'):
         from celery.backends.redis import RedisBackend
         from celery.backends.redis import RedisBackend

+ 61 - 45
celery/tests/backends/test_base.py

@@ -7,7 +7,6 @@ from contextlib import contextmanager
 from mock import Mock, patch
 from mock import Mock, patch
 from nose import SkipTest
 from nose import SkipTest
 
 
-from celery import current_app
 from celery.exceptions import ChordError
 from celery.exceptions import ChordError
 from celery.five import items, range
 from celery.five import items, range
 from celery.result import AsyncResult, GroupResult
 from celery.result import AsyncResult, GroupResult
@@ -40,10 +39,9 @@ else:
 Unpickleable = subclass_exception('Unpickleable', KeyError, 'foo.module')
 Unpickleable = subclass_exception('Unpickleable', KeyError, 'foo.module')
 Impossible = subclass_exception('Impossible', object, 'foo.module')
 Impossible = subclass_exception('Impossible', object, 'foo.module')
 Lookalike = subclass_exception('Lookalike', wrapobject, 'foo.module')
 Lookalike = subclass_exception('Lookalike', wrapobject, 'foo.module')
-b = BaseBackend()
 
 
 
 
-class test_serialization(Case):
+class test_serialization(AppCase):
 
 
     def test_create_exception_cls(self):
     def test_create_exception_cls(self):
         self.assertTrue(serialization.create_exception_cls('FooError', 'm'))
         self.assertTrue(serialization.create_exception_cls('FooError', 'm'))
@@ -51,27 +49,32 @@ class test_serialization(Case):
                                                            KeyError))
                                                            KeyError))
 
 
 
 
-class test_BaseBackend_interface(Case):
+class test_BaseBackend_interface(AppCase):
+
+    def setup(self):
+        self.b = BaseBackend(self.app)
 
 
     def test__forget(self):
     def test__forget(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            b._forget('SOMExx-N0Nex1stant-IDxx-')
+            self.b._forget('SOMExx-N0Nex1stant-IDxx-')
 
 
     def test_forget(self):
     def test_forget(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            b.forget('SOMExx-N0nex1stant-IDxx-')
+            self.b.forget('SOMExx-N0nex1stant-IDxx-')
 
 
     def test_on_chord_part_return(self):
     def test_on_chord_part_return(self):
-        b.on_chord_part_return(None)
+        self.b.on_chord_part_return(None)
 
 
     def test_on_chord_apply(self, unlock='celery.chord_unlock'):
     def test_on_chord_apply(self, unlock='celery.chord_unlock'):
-        p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
+        p, self.app.tasks[unlock] = self.app.tasks.get(unlock), Mock()
         try:
         try:
-            b.on_chord_apply('dakj221', 'sdokqweok',
-                             result=[AsyncResult(x) for x in [1, 2, 3]])
-            self.assertTrue(current_app.tasks[unlock].apply_async.call_count)
+            self.b.on_chord_apply(
+                'dakj221', 'sdokqweok',
+                result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
+            )
+            self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
         finally:
         finally:
-            current_app.tasks[unlock] = p
+            self.app.tasks[unlock] = p
 
 
 
 
 class test_exception_pickle(Case):
 class test_exception_pickle(Case):
@@ -93,19 +96,22 @@ class test_exception_pickle(Case):
         self.assertIsNone(fnpe(Impossible()))
         self.assertIsNone(fnpe(Impossible()))
 
 
 
 
-class test_prepare_exception(Case):
+class test_prepare_exception(AppCase):
+
+    def setup(self):
+        self.b = BaseBackend(self.app)
 
 
     def test_unpickleable(self):
     def test_unpickleable(self):
-        x = b.prepare_exception(Unpickleable(1, 2, 'foo'))
+        x = self.b.prepare_exception(Unpickleable(1, 2, 'foo'))
         self.assertIsInstance(x, KeyError)
         self.assertIsInstance(x, KeyError)
-        y = b.exception_to_python(x)
+        y = self.b.exception_to_python(x)
         self.assertIsInstance(y, KeyError)
         self.assertIsInstance(y, KeyError)
 
 
     def test_impossible(self):
     def test_impossible(self):
-        x = b.prepare_exception(Impossible())
+        x = self.b.prepare_exception(Impossible())
         self.assertIsInstance(x, UnpickleableExceptionWrapper)
         self.assertIsInstance(x, UnpickleableExceptionWrapper)
         self.assertTrue(str(x))
         self.assertTrue(str(x))
-        y = b.exception_to_python(x)
+        y = self.b.exception_to_python(x)
         self.assertEqual(y.__class__.__name__, 'Impossible')
         self.assertEqual(y.__class__.__name__, 'Impossible')
         if sys.version_info < (2, 5):
         if sys.version_info < (2, 5):
             self.assertTrue(y.__class__.__module__)
             self.assertTrue(y.__class__.__module__)
@@ -113,18 +119,18 @@ class test_prepare_exception(Case):
             self.assertEqual(y.__class__.__module__, 'foo.module')
             self.assertEqual(y.__class__.__module__, 'foo.module')
 
 
     def test_regular(self):
     def test_regular(self):
-        x = b.prepare_exception(KeyError('baz'))
+        x = self.b.prepare_exception(KeyError('baz'))
         self.assertIsInstance(x, KeyError)
         self.assertIsInstance(x, KeyError)
-        y = b.exception_to_python(x)
+        y = self.b.exception_to_python(x)
         self.assertIsInstance(y, KeyError)
         self.assertIsInstance(y, KeyError)
 
 
 
 
 class KVBackend(KeyValueStoreBackend):
 class KVBackend(KeyValueStoreBackend):
     mget_returns_dict = False
     mget_returns_dict = False
 
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, app, *args, **kwargs):
         self.db = {}
         self.db = {}
-        super(KVBackend, self).__init__()
+        super(KVBackend, self).__init__(app)
 
 
     def get(self, key):
     def get(self, key):
         return self.db.get(key)
         return self.db.get(key)
@@ -160,17 +166,17 @@ class DictBackend(BaseBackend):
         self._data.pop(group_id, None)
         self._data.pop(group_id, None)
 
 
 
 
-class test_BaseBackend_dict(Case):
+class test_BaseBackend_dict(AppCase):
 
 
-    def setUp(self):
-        self.b = DictBackend()
+    def setup(self):
+        self.b = DictBackend(app=self.app)
 
 
     def test_delete_group(self):
     def test_delete_group(self):
         self.b.delete_group('can-delete')
         self.b.delete_group('can-delete')
         self.assertNotIn('can-delete', self.b._data)
         self.assertNotIn('can-delete', self.b._data)
 
 
     def test_prepare_exception_json(self):
     def test_prepare_exception_json(self):
-        x = DictBackend(serializer='json')
+        x = DictBackend(self.app, serializer='json')
         e = x.prepare_exception(KeyError('foo'))
         e = x.prepare_exception(KeyError('foo'))
         self.assertIn('exc_type', e)
         self.assertIn('exc_type', e)
         e = x.exception_to_python(e)
         e = x.exception_to_python(e)
@@ -178,13 +184,13 @@ class test_BaseBackend_dict(Case):
         self.assertEqual(str(e), "'foo'")
         self.assertEqual(str(e), "'foo'")
 
 
     def test_save_group(self):
     def test_save_group(self):
-        b = BaseBackend()
+        b = BaseBackend(self.app)
         b._save_group = Mock()
         b._save_group = Mock()
         b.save_group('foofoo', 'xxx')
         b.save_group('foofoo', 'xxx')
         b._save_group.assert_called_with('foofoo', 'xxx')
         b._save_group.assert_called_with('foofoo', 'xxx')
 
 
     def test_forget_interface(self):
     def test_forget_interface(self):
-        b = BaseBackend()
+        b = BaseBackend(self.app)
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             b.forget('foo')
             b.forget('foo')
 
 
@@ -230,7 +236,7 @@ class test_BaseBackend_dict(Case):
 class test_KeyValueStoreBackend(AppCase):
 class test_KeyValueStoreBackend(AppCase):
 
 
     def setup(self):
     def setup(self):
-        self.b = KVBackend()
+        self.b = KVBackend(app=self.app)
 
 
     def test_on_chord_part_return(self):
     def test_on_chord_part_return(self):
         assert not self.b.implements_incr
         assert not self.b.implements_incr
@@ -330,7 +336,9 @@ class test_KeyValueStoreBackend(AppCase):
 
 
     def test_chord_part_return_join_raises_task(self):
     def test_chord_part_return_join_raises_task(self):
         with self._chord_part_context(self.b) as (task, deps, callback):
         with self._chord_part_context(self.b) as (task, deps, callback):
-            deps._failed_join_report = lambda: iter([AsyncResult('culprit')])
+            deps._failed_join_report = lambda: iter([
+                self.app.AsyncResult('culprit'),
+            ])
             deps.join_native.side_effect = KeyError('foo')
             deps.join_native.side_effect = KeyError('foo')
             self.b.on_chord_part_return(task)
             self.b.on_chord_part_return(task)
             self.assertTrue(self.b.fail_from_current_stack.called)
             self.assertTrue(self.b.fail_from_current_stack.called)
@@ -340,15 +348,21 @@ class test_KeyValueStoreBackend(AppCase):
             self.assertIn('Dependency culprit raised', str(exc))
             self.assertIn('Dependency culprit raised', str(exc))
 
 
     def test_restore_group_from_json(self):
     def test_restore_group_from_json(self):
-        b = KVBackend(serializer='json')
-        g = GroupResult('group_id', [AsyncResult('a'), AsyncResult('b')])
+        b = KVBackend(serializer='json', app=self.app)
+        g = self.app.GroupResult(
+            'group_id',
+            [self.app.AsyncResult('a'), self.app.AsyncResult('b')],
+        )
         b._save_group(g.id, g)
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
         g2 = b._restore_group(g.id)['result']
         self.assertEqual(g2, g)
         self.assertEqual(g2, g)
 
 
     def test_restore_group_from_pickle(self):
     def test_restore_group_from_pickle(self):
-        b = KVBackend(serializer='pickle')
-        g = GroupResult('group_id', [AsyncResult('a'), AsyncResult('b')])
+        b = KVBackend(serializer='pickle', app=self.app)
+        g = self.app.GroupResult(
+            'group_id',
+            [self.app.AsyncResult('a'), self.app.AsyncResult('b')],
+        )
         b._save_group(g.id, g)
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
         g2 = b._restore_group(g.id)['result']
         self.assertEqual(g2, g)
         self.assertEqual(g2, g)
@@ -367,7 +381,9 @@ class test_KeyValueStoreBackend(AppCase):
 
 
     def test_save_restore_delete_group(self):
     def test_save_restore_delete_group(self):
         tid = uuid()
         tid = uuid()
-        tsr = GroupResult(tid, [AsyncResult(uuid()) for _ in range(10)])
+        tsr = self.app.GroupResult(
+            tid, [self.app.AsyncResult(uuid()) for _ in range(10)],
+        )
         self.b.save_group(tid, tsr)
         self.b.save_group(tid, tsr)
         self.b.restore_group(tid)
         self.b.restore_group(tid)
         self.assertEqual(self.b.restore_group(tid), tsr)
         self.assertEqual(self.b.restore_group(tid), tsr)
@@ -378,41 +394,41 @@ class test_KeyValueStoreBackend(AppCase):
         self.assertIsNone(self.b.restore_group('xxx-nonexistant'))
         self.assertIsNone(self.b.restore_group('xxx-nonexistant'))
 
 
 
 
-class test_KeyValueStoreBackend_interface(Case):
+class test_KeyValueStoreBackend_interface(AppCase):
 
 
     def test_get(self):
     def test_get(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().get('a')
+            KeyValueStoreBackend(self.app).get('a')
 
 
     def test_set(self):
     def test_set(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().set('a', 1)
+            KeyValueStoreBackend(self.app).set('a', 1)
 
 
     def test_incr(self):
     def test_incr(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().incr('a')
+            KeyValueStoreBackend(self.app).incr('a')
 
 
     def test_cleanup(self):
     def test_cleanup(self):
-        self.assertFalse(KeyValueStoreBackend().cleanup())
+        self.assertFalse(KeyValueStoreBackend(self.app).cleanup())
 
 
     def test_delete(self):
     def test_delete(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().delete('a')
+            KeyValueStoreBackend(self.app).delete('a')
 
 
     def test_mget(self):
     def test_mget(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().mget(['a'])
+            KeyValueStoreBackend(self.app).mget(['a'])
 
 
     def test_forget(self):
     def test_forget(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().forget('a')
+            KeyValueStoreBackend(self.app).forget('a')
 
 
 
 
-class test_DisabledBackend(Case):
+class test_DisabledBackend(AppCase):
 
 
     def test_store_result(self):
     def test_store_result(self):
-        DisabledBackend().store_result()
+        DisabledBackend(self.app).store_result()
 
 
     def test_is_disabled(self):
     def test_is_disabled(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
-            DisabledBackend().get_status('foo')
+            DisabledBackend(self.app).get_status('foo')

+ 10 - 11
celery/tests/backends/test_cache.py

@@ -8,7 +8,6 @@ from contextlib import contextmanager
 from kombu.utils.encoding import str_to_bytes
 from kombu.utils.encoding import str_to_bytes
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import current_app
 from celery import states
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
@@ -67,13 +66,13 @@ class test_CacheBackend(AppCase):
             self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
             self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
 
 
     def test_on_chord_apply(self):
     def test_on_chord_apply(self):
-        tb = CacheBackend(backend='memory://')
+        tb = CacheBackend(backend='memory://', app=self.app)
         gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
         gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
         tb.on_chord_apply(gid, {}, result=res)
         tb.on_chord_apply(gid, {}, result=res)
 
 
     @patch('celery.result.GroupResult')
     @patch('celery.result.GroupResult')
     def test_on_chord_part_return(self, setresult):
     def test_on_chord_part_return(self, setresult):
-        tb = CacheBackend(backend='memory://')
+        tb = CacheBackend(backend='memory://', app=self.app)
 
 
         deps = Mock()
         deps = Mock()
         deps.__len__ = Mock()
         deps.__len__ = Mock()
@@ -82,7 +81,7 @@ class test_CacheBackend(AppCase):
         task = Mock()
         task = Mock()
         task.name = 'foobarbaz'
         task.name = 'foobarbaz'
         try:
         try:
-            current_app.tasks['foobarbaz'] = task
+            self.app.tasks['foobarbaz'] = task
             task.request.chord = subtask(task)
             task.request.chord = subtask(task)
 
 
             gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
             gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
@@ -98,7 +97,7 @@ class test_CacheBackend(AppCase):
             deps.delete.assert_called_with()
             deps.delete.assert_called_with()
 
 
         finally:
         finally:
-            current_app.tasks.pop('foobarbaz')
+            self.app.tasks.pop('foobarbaz')
 
 
     def test_mget(self):
     def test_mget(self):
         self.tb.set('foo', 1)
         self.tb.set('foo', 1)
@@ -117,12 +116,12 @@ class test_CacheBackend(AppCase):
         self.tb.process_cleanup()
         self.tb.process_cleanup()
 
 
     def test_expires_as_int(self):
     def test_expires_as_int(self):
-        tb = CacheBackend(backend='memory://', expires=10)
+        tb = CacheBackend(backend='memory://', expires=10, app=self.app)
         self.assertEqual(tb.expires, 10)
         self.assertEqual(tb.expires, 10)
 
 
     def test_unknown_backend_raises_ImproperlyConfigured(self):
     def test_unknown_backend_raises_ImproperlyConfigured(self):
         with self.assertRaises(ImproperlyConfigured):
         with self.assertRaises(ImproperlyConfigured):
-            CacheBackend(backend='unknown://')
+            CacheBackend(backend='unknown://', app=self.app)
 
 
 
 
 class MyMemcachedStringEncodingError(Exception):
 class MyMemcachedStringEncodingError(Exception):
@@ -218,7 +217,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     from celery.backends import cache
                     from celery.backends import cache
                     cache._imp = [None]
                     cache._imp = [None]
                     task_id, result = string(uuid()), 42
                     task_id, result = string(uuid()), 42
-                    b = cache.CacheBackend(backend='memcache')
+                    b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, status=states.SUCCESS)
                     b.store_result(task_id, result, status=states.SUCCESS)
                     self.assertEqual(b.get_result(task_id), result)
                     self.assertEqual(b.get_result(task_id), result)
 
 
@@ -229,7 +228,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     from celery.backends import cache
                     from celery.backends import cache
                     cache._imp = [None]
                     cache._imp = [None]
                     task_id, result = str_to_bytes(uuid()), 42
                     task_id, result = str_to_bytes(uuid()), 42
-                    b = cache.CacheBackend(backend='memcache')
+                    b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, status=states.SUCCESS)
                     b.store_result(task_id, result, status=states.SUCCESS)
                     self.assertEqual(b.get_result(task_id), result)
                     self.assertEqual(b.get_result(task_id), result)
 
 
@@ -239,7 +238,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
                 task_id, result = string(uuid()), 42
                 task_id, result = string(uuid()), 42
-                b = cache.CacheBackend(backend='memcache')
+                b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, status=states.SUCCESS)
                 b.store_result(task_id, result, status=states.SUCCESS)
                 self.assertEqual(b.get_result(task_id), result)
                 self.assertEqual(b.get_result(task_id), result)
 
 
@@ -249,6 +248,6 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
                 task_id, result = str_to_bytes(uuid()), 42
                 task_id, result = str_to_bytes(uuid()), 42
-                b = cache.CacheBackend(backend='memcache')
+                b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, status=states.SUCCESS)
                 b.store_result(task_id, result, status=states.SUCCESS)
                 self.assertEqual(b.get_result(task_id), result)
                 self.assertEqual(b.get_result(task_id), result)

+ 18 - 19
celery/tests/backends/test_database.py

@@ -6,13 +6,12 @@ from nose import SkipTest
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
 from celery import states
 from celery import states
-from celery.app import app_or_default
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.utils import uuid
 from celery.utils import uuid
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    Case,
+    AppCase,
     mask_modules,
     mask_modules,
     skip_if_pypy,
     skip_if_pypy,
     skip_if_jython,
     skip_if_jython,
@@ -33,11 +32,11 @@ class SomeClass(object):
         self.data = data
         self.data = data
 
 
 
 
-class test_DatabaseBackend(Case):
+class test_DatabaseBackend(AppCase):
 
 
     @skip_if_pypy
     @skip_if_pypy
     @skip_if_jython
     @skip_if_jython
-    def setUp(self):
+    def setup(self):
         if DatabaseBackend is None:
         if DatabaseBackend is None:
             raise SkipTest('sqlalchemy not installed')
             raise SkipTest('sqlalchemy not installed')
 
 
@@ -62,20 +61,20 @@ class test_DatabaseBackend(Case):
                 _sqlalchemy_installed()
                 _sqlalchemy_installed()
 
 
     def test_missing_dburi_raises_ImproperlyConfigured(self):
     def test_missing_dburi_raises_ImproperlyConfigured(self):
-        conf = app_or_default().conf
+        conf = self.app.conf
         prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
         prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
         try:
         try:
             with self.assertRaises(ImproperlyConfigured):
             with self.assertRaises(ImproperlyConfigured):
-                DatabaseBackend()
+                DatabaseBackend(app=self.app)
         finally:
         finally:
             conf.CELERY_RESULT_DBURI = prev
             conf.CELERY_RESULT_DBURI = prev
 
 
     def test_missing_task_id_is_PENDING(self):
     def test_missing_task_id_is_PENDING(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         self.assertEqual(tb.get_status('xxx-does-not-exist'), states.PENDING)
         self.assertEqual(tb.get_status('xxx-does-not-exist'), states.PENDING)
 
 
     def test_missing_task_meta_is_dict_with_pending(self):
     def test_missing_task_meta_is_dict_with_pending(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         self.assertDictContainsSubset({
         self.assertDictContainsSubset({
             'status': states.PENDING,
             'status': states.PENDING,
             'task_id': 'xxx-does-not-exist-at-all',
             'task_id': 'xxx-does-not-exist-at-all',
@@ -84,7 +83,7 @@ class test_DatabaseBackend(Case):
         }, tb.get_task_meta('xxx-does-not-exist-at-all'))
         }, tb.get_task_meta('xxx-does-not-exist-at-all'))
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
 
         tid = uuid()
         tid = uuid()
 
 
@@ -96,7 +95,7 @@ class test_DatabaseBackend(Case):
         self.assertEqual(tb.get_result(tid), 42)
         self.assertEqual(tb.get_result(tid), 42)
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
 
         tid2 = uuid()
         tid2 = uuid()
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
@@ -107,19 +106,19 @@ class test_DatabaseBackend(Case):
         self.assertEqual(rindb.get('bar').data, 12345)
         self.assertEqual(rindb.get('bar').data, 12345)
 
 
     def test_mark_as_started(self):
     def test_mark_as_started(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_started(tid)
         tb.mark_as_started(tid)
         self.assertEqual(tb.get_status(tid), states.STARTED)
         self.assertEqual(tb.get_status(tid), states.STARTED)
 
 
     def test_mark_as_revoked(self):
     def test_mark_as_revoked(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_revoked(tid)
         tb.mark_as_revoked(tid)
         self.assertEqual(tb.get_status(tid), states.REVOKED)
         self.assertEqual(tb.get_status(tid), states.REVOKED)
 
 
     def test_mark_as_retry(self):
     def test_mark_as_retry(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tid = uuid()
         tid = uuid()
         try:
         try:
             raise KeyError('foo')
             raise KeyError('foo')
@@ -132,7 +131,7 @@ class test_DatabaseBackend(Case):
             self.assertEqual(tb.get_traceback(tid), trace)
             self.assertEqual(tb.get_traceback(tid), trace)
 
 
     def test_mark_as_failure(self):
     def test_mark_as_failure(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
 
         tid3 = uuid()
         tid3 = uuid()
         try:
         try:
@@ -146,7 +145,7 @@ class test_DatabaseBackend(Case):
             self.assertEqual(tb.get_traceback(tid3), trace)
             self.assertEqual(tb.get_traceback(tid3), trace)
 
 
     def test_forget(self):
     def test_forget(self):
-        tb = DatabaseBackend(backend='memory://')
+        tb = DatabaseBackend(backend='memory://', app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
@@ -155,15 +154,15 @@ class test_DatabaseBackend(Case):
         self.assertIsNone(x.result)
         self.assertIsNone(x.result)
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tb.process_cleanup()
         tb.process_cleanup()
 
 
     def test_reduce(self):
     def test_reduce(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         self.assertTrue(loads(dumps(tb)))
         self.assertTrue(loads(dumps(tb)))
 
 
     def test_save__restore__delete_group(self):
     def test_save__restore__delete_group(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
 
         tid = uuid()
         tid = uuid()
         res = {'something': 'special'}
         res = {'something': 'special'}
@@ -178,7 +177,7 @@ class test_DatabaseBackend(Case):
         self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
         self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
 
 
     def test_cleanup(self):
     def test_cleanup(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         for i in range(10):
         for i in range(10):
             tb.mark_as_done(uuid(), 42)
             tb.mark_as_done(uuid(), 42)
             tb.save_group(uuid(), {'foo': 'bar'})
             tb.save_group(uuid(), {'foo': 'bar'})

+ 7 - 7
celery/tests/backends/test_mongodb.py

@@ -26,7 +26,7 @@ MONGODB_COLLECTION = 'collection1'
 
 
 class test_MongoBackend(AppCase):
 class test_MongoBackend(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         if pymongo is None:
         if pymongo is None:
             raise SkipTest('pymongo is not installed.')
             raise SkipTest('pymongo is not installed.')
 
 
@@ -36,9 +36,9 @@ class test_MongoBackend(AppCase):
         R['Binary'], module.Binary = module.Binary, Mock()
         R['Binary'], module.Binary = module.Binary, Mock()
         R['datetime'], datetime.datetime = datetime.datetime, Mock()
         R['datetime'], datetime.datetime = datetime.datetime, Mock()
 
 
-        self.backend = MongoBackend()
+        self.backend = MongoBackend(app=self.app)
 
 
-    def tearDown(self):
+    def teardown(self):
         MongoBackend.encode = self._reset['encode']
         MongoBackend.encode = self._reset['encode']
         MongoBackend.decode = self._reset['decode']
         MongoBackend.decode = self._reset['decode']
         module.Binary = self._reset['Binary']
         module.Binary = self._reset['Binary']
@@ -53,7 +53,7 @@ class test_MongoBackend(AppCase):
         prev, module.pymongo = module.pymongo, None
         prev, module.pymongo = module.pymongo, None
         try:
         try:
             with self.assertRaises(ImproperlyConfigured):
             with self.assertRaises(ImproperlyConfigured):
-                MongoBackend()
+                MongoBackend(app=self.app)
         finally:
         finally:
             module.pymongo = prev
             module.pymongo = prev
 
 
@@ -69,14 +69,14 @@ class test_MongoBackend(AppCase):
         MongoBackend(app=celery)
         MongoBackend(app=celery)
 
 
     def test_restore_group_no_entry(self):
     def test_restore_group_no_entry(self):
-        x = MongoBackend()
+        x = MongoBackend(app=self.app)
         x.collection = Mock()
         x.collection = Mock()
         fo = x.collection.find_one = Mock()
         fo = x.collection.find_one = Mock()
         fo.return_value = None
         fo.return_value = None
         self.assertIsNone(x._restore_group('1f3fab'))
         self.assertIsNone(x._restore_group('1f3fab'))
 
 
     def test_reduce(self):
     def test_reduce(self):
-        x = MongoBackend()
+        x = MongoBackend(app=self.app)
         self.assertTrue(loads(dumps(x)))
         self.assertTrue(loads(dumps(x)))
 
 
     def test_get_connection_connection_exists(self):
     def test_get_connection_connection_exists(self):
@@ -311,7 +311,7 @@ class test_MongoBackend(AppCase):
         mock_collection.assert_called_once()
         mock_collection.assert_called_once()
 
 
     def test_get_database_authfailure(self):
     def test_get_database_authfailure(self):
-        x = MongoBackend()
+        x = MongoBackend(app=self.app)
         x._get_connection = Mock()
         x._get_connection = Mock()
         conn = x._get_connection.return_value = {}
         conn = x._get_connection.return_value = {}
         db = conn[x.mongodb_database] = Mock()
         db = conn[x.mongodb_database] = Mock()

+ 24 - 25
celery/tests/backends/test_redis.py

@@ -8,7 +8,6 @@ from pickle import loads, dumps
 
 
 from kombu.utils import cached_property, uuid
 from kombu.utils import cached_property, uuid
 
 
-from celery import current_app
 from celery import states
 from celery import states
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
@@ -16,7 +15,7 @@ from celery.result import AsyncResult
 from celery.task import subtask
 from celery.task import subtask
 from celery.utils.timeutils import timedelta_seconds
 from celery.utils.timeutils import timedelta_seconds
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
 class Redis(object):
 class Redis(object):
@@ -65,7 +64,7 @@ class redis(object):
             pass
             pass
 
 
 
 
-class test_RedisBackend(Case):
+class test_RedisBackend(AppCase):
 
 
     def get_backend(self):
     def get_backend(self):
         from celery.backends import redis
         from celery.backends import redis
@@ -75,7 +74,7 @@ class test_RedisBackend(Case):
 
 
         return RedisBackend
         return RedisBackend
 
 
-    def setUp(self):
+    def setup(self):
         self.Backend = self.get_backend()
         self.Backend = self.get_backend()
 
 
         class MockBackend(self.Backend):
         class MockBackend(self.Backend):
@@ -89,7 +88,7 @@ class test_RedisBackend(Case):
     def test_reduce(self):
     def test_reduce(self):
         try:
         try:
             from celery.backends.redis import RedisBackend
             from celery.backends.redis import RedisBackend
-            x = RedisBackend()
+            x = RedisBackend(app=self.app)
             self.assertTrue(loads(dumps(x)))
             self.assertTrue(loads(dumps(x)))
         except ImportError:
         except ImportError:
             raise SkipTest('redis not installed')
             raise SkipTest('redis not installed')
@@ -97,10 +96,10 @@ class test_RedisBackend(Case):
     def test_no_redis(self):
     def test_no_redis(self):
         self.MockBackend.redis = None
         self.MockBackend.redis = None
         with self.assertRaises(ImproperlyConfigured):
         with self.assertRaises(ImproperlyConfigured):
-            self.MockBackend()
+            self.MockBackend(app=self.app)
 
 
     def test_url(self):
     def test_url(self):
-        x = self.MockBackend('redis://foobar//1')
+        x = self.MockBackend('redis://foobar//1', app=self.app)
         self.assertEqual(x.host, 'foobar')
         self.assertEqual(x.host, 'foobar')
         self.assertEqual(x.db, '1')
         self.assertEqual(x.db, '1')
 
 
@@ -108,54 +107,54 @@ class test_RedisBackend(Case):
         conf = AttributeDict({'CELERY_RESULT_SERIALIZER': 'json',
         conf = AttributeDict({'CELERY_RESULT_SERIALIZER': 'json',
                               'CELERY_MAX_CACHED_RESULTS': 1,
                               'CELERY_MAX_CACHED_RESULTS': 1,
                               'CELERY_TASK_RESULT_EXPIRES': None})
                               'CELERY_TASK_RESULT_EXPIRES': None})
-        prev, current_app.conf = current_app.conf, conf
+        prev, self.app.conf = self.app.conf, conf
         try:
         try:
-            self.MockBackend()
+            self.MockBackend(app=self.app)
         finally:
         finally:
-            current_app.conf = prev
+            self.app.conf = prev
 
 
     def test_expires_defaults_to_config(self):
     def test_expires_defaults_to_config(self):
-        conf = current_app.conf
+        conf = self.app.conf
         prev = conf.CELERY_TASK_RESULT_EXPIRES
         prev = conf.CELERY_TASK_RESULT_EXPIRES
         conf.CELERY_TASK_RESULT_EXPIRES = 10
         conf.CELERY_TASK_RESULT_EXPIRES = 10
         try:
         try:
-            b = self.Backend(expires=None)
+            b = self.Backend(expires=None, app=self.app)
             self.assertEqual(b.expires, 10)
             self.assertEqual(b.expires, 10)
         finally:
         finally:
             conf.CELERY_TASK_RESULT_EXPIRES = prev
             conf.CELERY_TASK_RESULT_EXPIRES = prev
 
 
     def test_expires_is_int(self):
     def test_expires_is_int(self):
-        b = self.Backend(expires=48)
+        b = self.Backend(expires=48, app=self.app)
         self.assertEqual(b.expires, 48)
         self.assertEqual(b.expires, 48)
 
 
     def test_expires_is_None(self):
     def test_expires_is_None(self):
-        b = self.Backend(expires=None)
+        b = self.Backend(expires=None, app=self.app)
         self.assertEqual(b.expires, timedelta_seconds(
         self.assertEqual(b.expires, timedelta_seconds(
-            current_app.conf.CELERY_TASK_RESULT_EXPIRES))
+            self.app.conf.CELERY_TASK_RESULT_EXPIRES))
 
 
     def test_expires_is_timedelta(self):
     def test_expires_is_timedelta(self):
-        b = self.Backend(expires=timedelta(minutes=1))
+        b = self.Backend(expires=timedelta(minutes=1), app=self.app)
         self.assertEqual(b.expires, 60)
         self.assertEqual(b.expires, 60)
 
 
     def test_on_chord_apply(self):
     def test_on_chord_apply(self):
-        self.Backend().on_chord_apply(
+        self.Backend(app=self.app).on_chord_apply(
             'group_id', {},
             'group_id', {},
             result=[AsyncResult(x) for x in [1, 2, 3]],
             result=[AsyncResult(x) for x in [1, 2, 3]],
         )
         )
 
 
     def test_mget(self):
     def test_mget(self):
-        b = self.MockBackend()
+        b = self.MockBackend(app=self.app)
         self.assertTrue(b.mget(['a', 'b', 'c']))
         self.assertTrue(b.mget(['a', 'b', 'c']))
         b.client.mget.assert_called_with(['a', 'b', 'c'])
         b.client.mget.assert_called_with(['a', 'b', 'c'])
 
 
     def test_set_no_expire(self):
     def test_set_no_expire(self):
-        b = self.MockBackend()
+        b = self.MockBackend(app=self.app)
         b.expires = None
         b.expires = None
         b.set('foo', 'bar')
         b.set('foo', 'bar')
 
 
     @patch('celery.result.GroupResult')
     @patch('celery.result.GroupResult')
     def test_on_chord_part_return(self, setresult):
     def test_on_chord_part_return(self, setresult):
-        b = self.MockBackend()
+        b = self.MockBackend(app=self.app)
         deps = Mock()
         deps = Mock()
         deps.__len__ = Mock()
         deps.__len__ = Mock()
         deps.__len__.return_value = 10
         deps.__len__.return_value = 10
@@ -164,7 +163,7 @@ class test_RedisBackend(Case):
         task = Mock()
         task = Mock()
         task.name = 'foobarbaz'
         task.name = 'foobarbaz'
         try:
         try:
-            current_app.tasks['foobarbaz'] = task
+            self.app.tasks['foobarbaz'] = task
             task.request.chord = subtask(task)
             task.request.chord = subtask(task)
             task.request.group = 'group_id'
             task.request.group = 'group_id'
 
 
@@ -178,13 +177,13 @@ class test_RedisBackend(Case):
 
 
             self.assertTrue(b.client.expire.call_count)
             self.assertTrue(b.client.expire.call_count)
         finally:
         finally:
-            current_app.tasks.pop('foobarbaz')
+            self.app.tasks.pop('foobarbaz')
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
-        self.Backend().process_cleanup()
+        self.Backend(app=self.app).process_cleanup()
 
 
     def test_get_set_forget(self):
     def test_get_set_forget(self):
-        b = self.Backend()
+        b = self.Backend(app=self.app)
         tid = uuid()
         tid = uuid()
         b.store_result(tid, 42, states.SUCCESS)
         b.store_result(tid, 42, states.SUCCESS)
         self.assertEqual(b.get_status(tid), states.SUCCESS)
         self.assertEqual(b.get_status(tid), states.SUCCESS)
@@ -193,7 +192,7 @@ class test_RedisBackend(Case):
         self.assertEqual(b.get_status(tid), states.PENDING)
         self.assertEqual(b.get_status(tid), states.PENDING)
 
 
     def test_set_expires(self):
     def test_set_expires(self):
-        b = self.Backend(expires=512)
+        b = self.Backend(expires=512, app=self.app)
         tid = uuid()
         tid = uuid()
         key = b.get_key_for_task(tid)
         key = b.get_key_for_task(tid)
         b.store_result(tid, 42, states.SUCCESS)
         b.store_result(tid, 42, states.SUCCESS)

+ 13 - 13
celery/tests/bin/test_beat.py

@@ -10,7 +10,6 @@ from mock import patch
 
 
 from celery import beat
 from celery import beat
 from celery import platforms
 from celery import platforms
-from celery.app import app_or_default
 from celery.bin import beat as beat_bin
 from celery.bin import beat as beat_bin
 from celery.apps import beat as beatapp
 from celery.apps import beat as beatapp
 
 
@@ -64,10 +63,10 @@ class MockBeat3(beatapp.Beat):
 class test_Beat(AppCase):
 class test_Beat(AppCase):
 
 
     def test_loglevel_string(self):
     def test_loglevel_string(self):
-        b = beatapp.Beat(loglevel='DEBUG')
+        b = beatapp.Beat(app=self.app, loglevel='DEBUG')
         self.assertEqual(b.loglevel, logging.DEBUG)
         self.assertEqual(b.loglevel, logging.DEBUG)
 
 
-        b2 = beatapp.Beat(loglevel=logging.DEBUG)
+        b2 = beatapp.Beat(app=self.app, loglevel=logging.DEBUG)
         self.assertEqual(b2.loglevel, logging.DEBUG)
         self.assertEqual(b2.loglevel, logging.DEBUG)
 
 
     def test_colorize(self):
     def test_colorize(self):
@@ -80,15 +79,15 @@ class test_Beat(AppCase):
         self.assertEqual(app.log.setup.call_args[1]['colorize'], False)
         self.assertEqual(app.log.setup.call_args[1]['colorize'], False)
 
 
     def test_init_loader(self):
     def test_init_loader(self):
-        b = beatapp.Beat()
+        b = beatapp.Beat(app=self.app)
         b.init_loader()
         b.init_loader()
 
 
     def test_process_title(self):
     def test_process_title(self):
-        b = beatapp.Beat()
+        b = beatapp.Beat(app=self.app)
         b.set_process_title()
         b.set_process_title()
 
 
     def test_run(self):
     def test_run(self):
-        b = MockBeat2()
+        b = MockBeat2(app=self.app)
         MockService.started = False
         MockService.started = False
         b.run()
         b.run()
         self.assertTrue(MockService.started)
         self.assertTrue(MockService.started)
@@ -109,8 +108,8 @@ class test_Beat(AppCase):
             platforms.signals = p
             platforms.signals = p
 
 
     def test_install_sync_handler(self):
     def test_install_sync_handler(self):
-        b = beatapp.Beat()
-        clock = MockService()
+        b = beatapp.Beat(app=self.app)
+        clock = MockService(app=self.app)
         MockService.in_sync = False
         MockService.in_sync = False
         handlers = self.psig(b.install_sync_handler, clock)
         handlers = self.psig(b.install_sync_handler, clock)
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
@@ -124,7 +123,7 @@ class test_Beat(AppCase):
             delattr(sys.stdout, 'logger')
             delattr(sys.stdout, 'logger')
         except AttributeError:
         except AttributeError:
             pass
             pass
-        b = beatapp.Beat()
+        b = beatapp.Beat(app=self.app)
         b.redirect_stdouts = False
         b.redirect_stdouts = False
         b.app.log.__class__._setup = False
         b.app.log.__class__._setup = False
         b.setup_logging()
         b.setup_logging()
@@ -134,14 +133,15 @@ class test_Beat(AppCase):
     @redirect_stdouts
     @redirect_stdouts
     @patch('celery.apps.beat.logger')
     @patch('celery.apps.beat.logger')
     def test_logs_errors(self, logger, stdout, stderr):
     def test_logs_errors(self, logger, stdout, stderr):
-        b = MockBeat3(socket_timeout=None)
+        b = MockBeat3(app=self.app, socket_timeout=None)
         b.start_scheduler()
         b.start_scheduler()
         self.assertTrue(logger.critical.called)
         self.assertTrue(logger.critical.called)
 
 
     @redirect_stdouts
     @redirect_stdouts
     @patch('celery.platforms.create_pidlock')
     @patch('celery.platforms.create_pidlock')
     def test_use_pidfile(self, create_pidlock, stdout, stderr):
     def test_use_pidfile(self, create_pidlock, stdout, stderr):
-        b = MockBeat2(pidfile='pidfilelockfilepid', socket_timeout=None)
+        b = MockBeat2(app=self.app, pidfile='pidfilelockfilepid',
+                      socket_timeout=None)
         b.start_scheduler()
         b.start_scheduler()
         self.assertTrue(create_pidlock.called)
         self.assertTrue(create_pidlock.called)
 
 
@@ -184,13 +184,13 @@ class test_div(AppCase):
 
 
     def test_detach(self):
     def test_detach(self):
         cmd = beat_bin.beat()
         cmd = beat_bin.beat()
-        cmd.app = app_or_default()
+        cmd.app = self.app
         cmd.run(detach=True)
         cmd.run(detach=True)
         self.assertTrue(MockDaemonContext.opened)
         self.assertTrue(MockDaemonContext.opened)
         self.assertTrue(MockDaemonContext.closed)
         self.assertTrue(MockDaemonContext.closed)
 
 
     def test_parse_options(self):
     def test_parse_options(self):
         cmd = beat_bin.beat()
         cmd = beat_bin.beat()
-        cmd.app = app_or_default()
+        cmd.app = self.app
         options, args = cmd.parse_options('celery beat', ['-s', 'foo'])
         options, args = cmd.parse_options('celery beat', ['-s', 'foo'])
         self.assertEqual(options.schedule, 'foo')
         self.assertEqual(options.schedule, 'foo')

+ 2 - 2
celery/tests/bin/test_celeryd_detach.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import current_app
+from celery.platforms import IS_WINDOWS
 from celery.bin.celeryd_detach import (
 from celery.bin.celeryd_detach import (
     detach,
     detach,
     detached_celeryd,
     detached_celeryd,
@@ -12,7 +12,7 @@ from celery.bin.celeryd_detach import (
 from celery.tests.case import Case, override_stdouts
 from celery.tests.case import Case, override_stdouts
 
 
 
 
-if not current_app.IS_WINDOWS:
+if not IS_WINDOWS:
     class test_detached(Case):
     class test_detached(Case):
 
 
         @patch('celery.bin.celeryd_detach.detached')
         @patch('celery.bin.celeryd_detach.detached')

+ 3 - 5
celery/tests/bin/test_events.py

@@ -3,10 +3,9 @@ from __future__ import absolute_import
 from nose import SkipTest
 from nose import SkipTest
 from mock import patch as mpatch
 from mock import patch as mpatch
 
 
-from celery.app import app_or_default
 from celery.bin import events
 from celery.bin import events
 
 
-from celery.tests.case import Case, _old_patch as patch
+from celery.tests.case import AppCase, _old_patch as patch
 
 
 
 
 class MockCommand(object):
 class MockCommand(object):
@@ -21,10 +20,9 @@ def proctitle(prog, info=None):
 proctitle.last = ()
 proctitle.last = ()
 
 
 
 
-class test_events(Case):
+class test_events(AppCase):
 
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.ev = events.events(app=self.app)
         self.ev = events.events(app=self.app)
 
 
     @patch('celery.events.dumper', 'evdump', lambda **kw: 'me dumper, you?')
     @patch('celery.events.dumper', 'evdump', lambda **kw: 'me dumper, you?')

+ 34 - 32
celery/tests/bin/test_worker.py

@@ -15,7 +15,6 @@ from kombu import Exchange, Queue
 from celery import Celery
 from celery import Celery
 from celery import platforms
 from celery import platforms
 from celery import signals
 from celery import signals
-from celery import current_app
 from celery.app import trace
 from celery.app import trace
 from celery.apps import worker as cd
 from celery.apps import worker as cd
 from celery.bin.worker import worker, main as worker_main
 from celery.bin.worker import worker, main as worker_main
@@ -152,7 +151,7 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_loglevel_string(self):
     def test_loglevel_string(self):
-        worker = self.Worker(loglevel='INFO')
+        worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
         self.assertEqual(worker.loglevel, logging.INFO)
 
 
     def test_run_worker(self):
     def test_run_worker(self):
@@ -166,14 +165,14 @@ class test_Worker(WorkerAppCase):
         p = platforms.signals
         p = platforms.signals
         platforms.signals = Signals()
         platforms.signals = Signals()
         try:
         try:
-            w = self.Worker()
+            w = self.Worker(app=self.app)
             w._isatty = False
             w._isatty = False
             w.on_start()
             w.on_start()
             for sig in 'SIGINT', 'SIGHUP', 'SIGTERM':
             for sig in 'SIGINT', 'SIGHUP', 'SIGTERM':
                 self.assertIn(sig, handlers)
                 self.assertIn(sig, handlers)
 
 
             handlers.clear()
             handlers.clear()
-            w = self.Worker()
+            w = self.Worker(app=self.app)
             w._isatty = True
             w._isatty = True
             w.on_start()
             w.on_start()
             for sig in 'SIGINT', 'SIGTERM':
             for sig in 'SIGINT', 'SIGTERM':
@@ -184,7 +183,7 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_startup_info(self):
     def test_startup_info(self):
-        worker = self.Worker()
+        worker = self.Worker(app=self.app)
         worker.on_start()
         worker.on_start()
         self.assertTrue(worker.startup_info())
         self.assertTrue(worker.startup_info())
         worker.loglevel = logging.DEBUG
         worker.loglevel = logging.DEBUG
@@ -211,7 +210,7 @@ class test_Worker(WorkerAppCase):
             app.loader = prev
             app.loader = prev
 
 
         from celery.loaders.app import AppLoader
         from celery.loaders.app import AppLoader
-        prev, app.loader = app.loader, AppLoader()
+        prev, app.loader = app.loader, AppLoader(app=self.app)
         try:
         try:
             self.assertTrue(worker.startup_info())
             self.assertTrue(worker.startup_info())
         finally:
         finally:
@@ -227,18 +226,18 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_run(self):
     def test_run(self):
-        self.Worker().on_start()
-        self.Worker(purge=True).on_start()
-        worker = self.Worker()
+        self.Worker(app=self.app).on_start()
+        self.Worker(app=self.app, purge=True).on_start()
+        worker = self.Worker(app=self.app)
         worker.on_start()
         worker.on_start()
 
 
     @disable_stdouts
     @disable_stdouts
     def test_purge_messages(self):
     def test_purge_messages(self):
-        self.Worker().purge_messages()
+        self.Worker(app=self.app).purge_messages()
 
 
     @disable_stdouts
     @disable_stdouts
     def test_init_queues(self):
     def test_init_queues(self):
-        app = current_app
+        app = self.app
         c = app.conf
         c = app.conf
         p, app.amqp.queues = app.amqp.queues, app.amqp.Queues({
         p, app.amqp.queues = app.amqp.queues, app.amqp.Queues({
             'celery': {'exchange': 'celery',
             'celery': {'exchange': 'celery',
@@ -246,7 +245,7 @@ class test_Worker(WorkerAppCase):
             'video': {'exchange': 'video',
             'video': {'exchange': 'video',
                       'routing_key': 'video'}})
                       'routing_key': 'video'}})
         try:
         try:
-            worker = self.Worker()
+            worker = self.Worker(app=self.app)
             worker.setup_queues(['video'])
             worker.setup_queues(['video'])
             self.assertIn('video', app.amqp.queues)
             self.assertIn('video', app.amqp.queues)
             self.assertIn('video', app.amqp.queues.consume_from)
             self.assertIn('video', app.amqp.queues.consume_from)
@@ -256,10 +255,10 @@ class test_Worker(WorkerAppCase):
             c.CELERY_CREATE_MISSING_QUEUES = False
             c.CELERY_CREATE_MISSING_QUEUES = False
             del(app.amqp.queues)
             del(app.amqp.queues)
             with self.assertRaises(ImproperlyConfigured):
             with self.assertRaises(ImproperlyConfigured):
-                self.Worker().setup_queues(['image'])
+                self.Worker(app=self.app).setup_queues(['image'])
             del(app.amqp.queues)
             del(app.amqp.queues)
             c.CELERY_CREATE_MISSING_QUEUES = True
             c.CELERY_CREATE_MISSING_QUEUES = True
-            worker = self.Worker()
+            worker = self.Worker(app=self.app)
             worker.setup_queues(queues=['image'])
             worker.setup_queues(queues=['image'])
             self.assertIn('image', app.amqp.queues.consume_from)
             self.assertIn('image', app.amqp.queues.consume_from)
             self.assertEqual(Queue('image', Exchange('image'),
             self.assertEqual(Queue('image', Exchange('image'),
@@ -269,31 +268,32 @@ class test_Worker(WorkerAppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_autoscale_argument(self):
     def test_autoscale_argument(self):
-        worker1 = self.Worker(autoscale='10,3')
+        worker1 = self.Worker(app=self.app, autoscale='10,3')
         self.assertListEqual(worker1.autoscale, [10, 3])
         self.assertListEqual(worker1.autoscale, [10, 3])
-        worker2 = self.Worker(autoscale='10')
+        worker2 = self.Worker(app=self.app, autoscale='10')
         self.assertListEqual(worker2.autoscale, [10, 0])
         self.assertListEqual(worker2.autoscale, [10, 0])
 
 
     def test_include_argument(self):
     def test_include_argument(self):
-        worker1 = self.Worker(include='some.module')
+        worker1 = self.Worker(app=self.app, include='some.module')
         self.assertListEqual(worker1.include, ['some.module'])
         self.assertListEqual(worker1.include, ['some.module'])
-        worker2 = self.Worker(include='some.module,another.package')
+        worker2 = self.Worker(app=self.app,
+                              include='some.module,another.package')
         self.assertListEqual(
         self.assertListEqual(
             worker2.include,
             worker2.include,
             ['some.module', 'another.package'],
             ['some.module', 'another.package'],
         )
         )
-        self.Worker(include=['os', 'sys'])
+        self.Worker(app=self.app, include=['os', 'sys'])
 
 
     @disable_stdouts
     @disable_stdouts
     def test_unknown_loglevel(self):
     def test_unknown_loglevel(self):
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(loglevel='ALIEN')
             worker(app=self.app).run(loglevel='ALIEN')
-        worker1 = self.Worker(loglevel=0xFFFF)
+        worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
 
     @disable_stdouts
     @disable_stdouts
     def test_warns_if_running_as_privileged_user(self):
     def test_warns_if_running_as_privileged_user(self):
-        app = current_app
+        app = self.app
         if app.IS_WINDOWS:
         if app.IS_WINDOWS:
             raise SkipTest('Not applicable on Windows')
             raise SkipTest('Not applicable on Windows')
 
 
@@ -305,14 +305,14 @@ class test_Worker(WorkerAppCase):
             with self.assertWarnsRegex(
             with self.assertWarnsRegex(
                     RuntimeWarning,
                     RuntimeWarning,
                     r'superuser privileges is discouraged'):
                     r'superuser privileges is discouraged'):
-                worker = self.Worker()
+                worker = self.Worker(app=self.app)
                 worker.on_start()
                 worker.on_start()
         finally:
         finally:
             os.getuid = prev
             os.getuid = prev
 
 
     @disable_stdouts
     @disable_stdouts
     def test_redirect_stdouts(self):
     def test_redirect_stdouts(self):
-        self.Worker(redirect_stdouts=False)
+        self.Worker(app=self.app, redirect_stdouts=False)
         with self.assertRaises(AttributeError):
         with self.assertRaises(AttributeError):
             sys.stdout.logger
             sys.stdout.logger
 
 
@@ -322,7 +322,7 @@ class test_Worker(WorkerAppCase):
             self.app.log.redirect_stdouts, Mock(),
             self.app.log.redirect_stdouts, Mock(),
         )
         )
         try:
         try:
-            worker = self.Worker(redirect_stoutds=True)
+            worker = self.Worker(app=self.app, redirect_stoutds=True)
             worker._custom_logging = True
             worker._custom_logging = True
             worker.on_start()
             worker.on_start()
             self.assertFalse(self.app.log.redirect_stdouts.called)
             self.assertFalse(self.app.log.redirect_stdouts.called)
@@ -330,14 +330,16 @@ class test_Worker(WorkerAppCase):
             self.app.log.redirect_stdouts = prev
             self.app.log.redirect_stdouts = prev
 
 
     def test_setup_logging_no_color(self):
     def test_setup_logging_no_color(self):
-        worker = self.Worker(redirect_stdouts=False, no_color=True)
+        worker = self.Worker(
+            app=self.app, redirect_stdouts=False, no_color=True,
+        )
         prev, self.app.log.setup = self.app.log.setup, Mock()
         prev, self.app.log.setup = self.app.log.setup, Mock()
         worker.setup_logging()
         worker.setup_logging()
         self.assertFalse(self.app.log.setup.call_args[1]['colorize'])
         self.assertFalse(self.app.log.setup.call_args[1]['colorize'])
 
 
     @disable_stdouts
     @disable_stdouts
     def test_startup_info_pool_is_str(self):
     def test_startup_info_pool_is_str(self):
-        worker = self.Worker(redirect_stdouts=False)
+        worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker.pool_cls = 'foo'
         worker.pool_cls = 'foo'
         worker.startup_info()
         worker.startup_info()
 
 
@@ -349,7 +351,7 @@ class test_Worker(WorkerAppCase):
             logging_setup[0] = True
             logging_setup[0] = True
 
 
         try:
         try:
-            worker = self.Worker(redirect_stdouts=False)
+            worker = self.Worker(app=self.app, redirect_stdouts=False)
             worker.app.log.__class__._setup = False
             worker.app.log.__class__._setup = False
             worker.setup_logging()
             worker.setup_logging()
             self.assertTrue(logging_setup[0])
             self.assertTrue(logging_setup[0])
@@ -367,7 +369,7 @@ class test_Worker(WorkerAppCase):
             def osx_proxy_detection_workaround(self):
             def osx_proxy_detection_workaround(self):
                 self.proxy_workaround_installed = True
                 self.proxy_workaround_installed = True
 
 
-        worker = OSXWorker(redirect_stdouts=False)
+        worker = OSXWorker(app=self.app, redirect_stdouts=False)
 
 
         def install_HUP_nosupport(controller):
         def install_HUP_nosupport(controller):
             controller.hup_not_supported_installed = True
             controller.hup_not_supported_installed = True
@@ -400,7 +402,7 @@ class test_Worker(WorkerAppCase):
         prev = cd.install_worker_restart_handler
         prev = cd.install_worker_restart_handler
         cd.install_worker_restart_handler = install_worker_restart_handler
         cd.install_worker_restart_handler = install_worker_restart_handler
         try:
         try:
-            worker = self.Worker()
+            worker = self.Worker(app=self.app)
             worker.app.IS_OSX = False
             worker.app.IS_OSX = False
             worker.install_platform_tweaks(Controller())
             worker.install_platform_tweaks(Controller())
             self.assertTrue(restart_worker_handler_installed[0])
             self.assertTrue(restart_worker_handler_installed[0])
@@ -415,7 +417,7 @@ class test_Worker(WorkerAppCase):
         def on_worker_ready(**kwargs):
         def on_worker_ready(**kwargs):
             worker_ready_sent[0] = True
             worker_ready_sent[0] = True
 
 
-        self.Worker().on_consumer_ready(object())
+        self.Worker(app=self.app).on_consumer_ready(object())
         self.assertTrue(worker_ready_sent[0])
         self.assertTrue(worker_ready_sent[0])
 
 
 
 
@@ -430,7 +432,7 @@ class test_funs(WorkerAppCase):
             __import__('setproctitle')
             __import__('setproctitle')
         except ImportError:
         except ImportError:
             raise SkipTest('setproctitle not installed')
             raise SkipTest('setproctitle not installed')
-        worker = Worker(hostname='xyzza')
+        worker = Worker(app=self.app, hostname='xyzza')
         prev1, sys.argv = sys.argv, ['Arg0']
         prev1, sys.argv = sys.argv, ['Arg0']
         try:
         try:
             st = worker.set_process_status('Running')
             st = worker.set_process_status('Running')
@@ -452,7 +454,7 @@ class test_funs(WorkerAppCase):
     @disable_stdouts
     @disable_stdouts
     def test_parse_options(self):
     def test_parse_options(self):
         cmd = worker()
         cmd = worker()
-        cmd.app = current_app
+        cmd.app = self.app
         opts, args = cmd.parse_options('worker', ['--concurrency=512'])
         opts, args = cmd.parse_options('worker', ['--concurrency=512'])
         self.assertEqual(opts.concurrency, 512)
         self.assertEqual(opts.concurrency, 512)
 
 

+ 2 - 18
celery/tests/case.py

@@ -30,7 +30,6 @@ from nose import SkipTest
 from kombu.log import NullHandler
 from kombu.log import NullHandler
 from kombu.utils import nested
 from kombu.utils import nested
 
 
-from celery.app import app_or_default
 from celery.five import (
 from celery.five import (
     WhateverIO, builtins, items, reraise,
     WhateverIO, builtins, items, reraise,
     string_t, values, open_fqdn,
     string_t, values, open_fqdn,
@@ -236,9 +235,7 @@ def wrap_logger(logger, loglevel=logging.ERROR):
 
 
 
 
 @contextmanager
 @contextmanager
-def eager_tasks():
-    app = app_or_default()
-
+def eager_tasks(app):
     prev = app.conf.CELERY_ALWAYS_EAGER
     prev = app.conf.CELERY_ALWAYS_EAGER
     app.conf.CELERY_ALWAYS_EAGER = True
     app.conf.CELERY_ALWAYS_EAGER = True
     try:
     try:
@@ -247,19 +244,6 @@ def eager_tasks():
         app.conf.CELERY_ALWAYS_EAGER = prev
         app.conf.CELERY_ALWAYS_EAGER = prev
 
 
 
 
-def with_eager_tasks(fun):
-
-    @wraps(fun)
-    def _inner(*args, **kwargs):
-        app = app_or_default()
-        prev = app.conf.CELERY_ALWAYS_EAGER
-        app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            return fun(*args, **kwargs)
-        finally:
-            app.conf.CELERY_ALWAYS_EAGER = prev
-
-
 def with_environ(env_name, env_value):
 def with_environ(env_name, env_value):
 
 
     def _envpatched(fun):
     def _envpatched(fun):
@@ -547,7 +531,7 @@ def patch_many(*targets):
 
 
 
 
 @contextmanager
 @contextmanager
-def patch_settings(app=None, **config):
+def patch_settings(app, **config):
     if app is None:
     if app is None:
         from celery import current_app
         from celery import current_app
         app = current_app
         app = current_app

+ 13 - 12
celery/tests/compat_modules/test_sets.py

@@ -4,12 +4,11 @@ import anyjson
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import current_app
 from celery.task import Task
 from celery.task import Task
 from celery.task.sets import subtask, TaskSet
 from celery.task.sets import subtask, TaskSet
 from celery.canvas import Signature
 from celery.canvas import Signature
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
 class MockTask(Task):
 class MockTask(Task):
@@ -27,7 +26,7 @@ class MockTask(Task):
         return (args, kwargs, options)
         return (args, kwargs, options)
 
 
 
 
-class test_subtask(Case):
+class test_subtask(AppCase):
 
 
     def test_behaves_like_type(self):
     def test_behaves_like_type(self):
         s = subtask('tasks.add', (2, 2), {'cache': True},
         s = subtask('tasks.add', (2, 2), {'cache': True},
@@ -104,15 +103,15 @@ class test_subtask(Case):
         self.assertDictEqual(dict(cls(*args)), dict(s))
         self.assertDictEqual(dict(cls(*args)), dict(s))
 
 
 
 
-class test_TaskSet(Case):
+class test_TaskSet(AppCase):
 
 
     def test_task_arg_can_be_iterable__compat(self):
     def test_task_arg_can_be_iterable__compat(self):
         ts = TaskSet([MockTask.subtask((i, i))
         ts = TaskSet([MockTask.subtask((i, i))
-                      for i in (2, 4, 8)])
+                      for i in (2, 4, 8)], app=self.app)
         self.assertEqual(len(ts), 3)
         self.assertEqual(len(ts), 3)
 
 
     def test_respects_ALWAYS_EAGER(self):
     def test_respects_ALWAYS_EAGER(self):
-        app = current_app
+        app = self.app
 
 
         class MockTaskSet(TaskSet):
         class MockTaskSet(TaskSet):
             applied = 0
             applied = 0
@@ -122,6 +121,7 @@ class test_TaskSet(Case):
 
 
         ts = MockTaskSet(
         ts = MockTaskSet(
             [MockTask.subtask((i, i)) for i in (2, 4, 8)],
             [MockTask.subtask((i, i)) for i in (2, 4, 8)],
+            app=self.app,
         )
         )
         app.conf.CELERY_ALWAYS_EAGER = True
         app.conf.CELERY_ALWAYS_EAGER = True
         try:
         try:
@@ -145,7 +145,7 @@ class test_TaskSet(Case):
                 applied[0] += 1
                 applied[0] += 1
 
 
         ts = TaskSet([mocksubtask(MockTask, (i, i))
         ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)])
+                      for i in (2, 4, 8)], app=self.app)
         ts.apply_async()
         ts.apply_async()
         self.assertEqual(applied[0], 3)
         self.assertEqual(applied[0], 3)
 
 
@@ -158,9 +158,10 @@ class test_TaskSet(Case):
 
 
         # setting current_task
         # setting current_task
 
 
-        @current_app.task
+        @self.app.task
         def xyz():
         def xyz():
             pass
             pass
+
         from celery._state import _task_stack
         from celery._state import _task_stack
         xyz.push_request()
         xyz.push_request()
         _task_stack.push(xyz)
         _task_stack.push(xyz)
@@ -180,21 +181,21 @@ class test_TaskSet(Case):
                 applied[0] += 1
                 applied[0] += 1
 
 
         ts = TaskSet([mocksubtask(MockTask, (i, i))
         ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)])
+                      for i in (2, 4, 8)], app=self.app)
         ts.apply()
         ts.apply()
         self.assertEqual(applied[0], 3)
         self.assertEqual(applied[0], 3)
 
 
     def test_set_app(self):
     def test_set_app(self):
-        ts = TaskSet([])
+        ts = TaskSet([], app=self.app)
         ts.app = 42
         ts.app = 42
         self.assertEqual(ts.app, 42)
         self.assertEqual(ts.app, 42)
 
 
     def test_set_tasks(self):
     def test_set_tasks(self):
-        ts = TaskSet([])
+        ts = TaskSet([], app=self.app)
         ts.tasks = [1, 2, 3]
         ts.tasks = [1, 2, 3]
         self.assertEqual(ts, [1, 2, 3])
         self.assertEqual(ts, [1, 2, 3])
 
 
     def test_set_Publisher(self):
     def test_set_Publisher(self):
-        ts = TaskSet([])
+        ts = TaskSet([], app=self.app)
         ts.Publisher = 42
         ts.Publisher = 42
         self.assertEqual(ts.Publisher, 42)
         self.assertEqual(ts.Publisher, 42)

+ 4 - 4
celery/tests/events/test_cursesmon.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 
 from nose import SkipTest
 from nose import SkipTest
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
 class MockWindow(object):
 class MockWindow(object):
@@ -11,16 +11,16 @@ class MockWindow(object):
         return self.y, self.x
         return self.y, self.x
 
 
 
 
-class test_CursesDisplay(Case):
+class test_CursesDisplay(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         try:
         try:
             import curses  # noqa
             import curses  # noqa
         except ImportError:
         except ImportError:
             raise SkipTest('curses monitor requires curses')
             raise SkipTest('curses monitor requires curses')
 
 
         from celery.events import cursesmon
         from celery.events import cursesmon
-        self.monitor = cursesmon.CursesMonitor(object())
+        self.monitor = cursesmon.CursesMonitor(object(), app=self.app)
         self.win = MockWindow()
         self.win = MockWindow()
         self.monitor.win = self.win
         self.monitor.win = self.win
 
 

+ 6 - 9
celery/tests/events/test_snapshot.py

@@ -2,10 +2,9 @@ from __future__ import absolute_import
 
 
 from mock import patch
 from mock import patch
 
 
-from celery.app import app_or_default
 from celery.events import Events
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
 from celery.events.snapshot import Polaroid, evcam
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
 class TRef(object):
 class TRef(object):
@@ -28,10 +27,9 @@ class MockTimer(object):
 timer = MockTimer()
 timer = MockTimer()
 
 
 
 
-class test_Polaroid(Case):
+class test_Polaroid(AppCase):
 
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.state = self.app.events.State()
         self.state = self.app.events.State()
 
 
     def test_constructor(self):
     def test_constructor(self):
@@ -99,7 +97,7 @@ class test_Polaroid(Case):
         self.assertEqual(shutter_signal_sent[0], 1)
         self.assertEqual(shutter_signal_sent[0], 1)
 
 
 
 
-class test_evcam(Case):
+class test_evcam(AppCase):
 
 
     class MockReceiver(object):
     class MockReceiver(object):
         raise_keyboard_interrupt = False
         raise_keyboard_interrupt = False
@@ -113,12 +111,11 @@ class test_evcam(Case):
         def Receiver(self, *args, **kwargs):
         def Receiver(self, *args, **kwargs):
             return test_evcam.MockReceiver()
             return test_evcam.MockReceiver()
 
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.prev, self.app.events = self.app.events, self.MockEvents()
         self.prev, self.app.events = self.app.events, self.MockEvents()
         self.app.events.app = self.app
         self.app.events.app = self.app
 
 
-    def tearDown(self):
+    def teardown(self):
         self.app.events = self.prev
         self.app.events = self.prev
 
 
     def test_evcam(self):
     def test_evcam(self):

+ 3 - 3
celery/tests/security/case.py

@@ -2,12 +2,12 @@ from __future__ import absolute_import
 
 
 from nose import SkipTest
 from nose import SkipTest
 
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 
 
-class SecurityCase(Case):
+class SecurityCase(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         try:
         try:
             from OpenSSL import crypto  # noqa
             from OpenSSL import crypto  # noqa
         except ImportError:
         except ImportError:

+ 40 - 28
celery/tests/security/test_security.py

@@ -18,10 +18,9 @@ from __future__ import absolute_import
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import current_app
 from celery.exceptions import ImproperlyConfigured, SecurityError
 from celery.exceptions import ImproperlyConfigured, SecurityError
 from celery.five import builtins
 from celery.five import builtins
-from celery.security import setup_security, disable_untrusted_serializers
+from celery.security import disable_untrusted_serializers
 from celery.security.utils import reraise_errors
 from celery.security.utils import reraise_errors
 from kombu.serialization import registry
 from kombu.serialization import registry
 
 
@@ -55,11 +54,14 @@ class test_security(SecurityCase):
         disabled = registry._disabled_content_types
         disabled = registry._disabled_content_types
         self.assertEqual(0, len(disabled))
         self.assertEqual(0, len(disabled))
 
 
-        current_app.conf.CELERY_TASK_SERIALIZER = 'json'
-
-        setup_security()
-        self.assertIn('application/x-python-serialize', disabled)
-        disabled.clear()
+        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
+            self.app.conf.CELERY_TASK_SERIALIZER, 'json')
+        try:
+            self.app.setup_security()
+            self.assertIn('application/x-python-serialize', disabled)
+            disabled.clear()
+        finally:
+            self.app.conf.CELERY_TASK_SERIALIZER = prev
 
 
     @patch('celery.security.register_auth')
     @patch('celery.security.register_auth')
     @patch('celery.security.disable_untrusted_serializers')
     @patch('celery.security.disable_untrusted_serializers')
@@ -74,29 +76,39 @@ class test_security(SecurityCase):
             finally:
             finally:
                 calls[0] += 1
                 calls[0] += 1
 
 
-        with mock_open(side_effect=effect):
-            with patch('celery.security.registry') as registry:
-                store = Mock()
-                setup_security(['json'], key, cert, store)
-                dis.assert_called_with(['json'])
-                reg.assert_called_with('A', 'B', store, 'sha1', 'json')
-                registry._set_default_serializer.assert_called_with('auth')
+        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
+                self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
+        try:
+            with mock_open(side_effect=effect):
+                with patch('celery.security.registry') as registry:
+                    store = Mock()
+                    self.app.setup_security(['json'], key, cert, store)
+                    dis.assert_called_with(['json'])
+                    reg.assert_called_with('A', 'B', store, 'sha1', 'json')
+                    registry._set_default_serializer.assert_called_with('auth')
+        finally:
+            self.app.conf.CELERY_TASK_SERIALIZER = prev
 
 
     def test_security_conf(self):
     def test_security_conf(self):
-        current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
-
-        self.assertRaises(ImproperlyConfigured, setup_security)
-
-        _import = builtins.__import__
-
-        def import_hook(name, *args, **kwargs):
-            if name == 'OpenSSL':
-                raise ImportError
-            return _import(name, *args, **kwargs)
-
-        builtins.__import__ = import_hook
-        self.assertRaises(ImproperlyConfigured, setup_security)
-        builtins.__import__ = _import
+        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
+            self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
+        try:
+            with self.assertRaises(ImproperlyConfigured):
+                self.app.setup_security()
+
+            _import = builtins.__import__
+
+            def import_hook(name, *args, **kwargs):
+                if name == 'OpenSSL':
+                    raise ImportError
+                return _import(name, *args, **kwargs)
+
+            builtins.__import__ = import_hook
+            with self.assertRaises(ImproperlyConfigured):
+                self.app.setup_security()
+            builtins.__import__ = _import
+        finally:
+            self.app.conf.CELERY_TASK_SERIALIZER = prev
 
 
     def test_reraise_errors(self):
     def test_reraise_errors(self):
         with self.assertRaises(SecurityError):
         with self.assertRaises(SecurityError):

+ 27 - 28
celery/tests/tasks/test_chord.py

@@ -3,26 +3,25 @@ from __future__ import absolute_import
 from mock import patch
 from mock import patch
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
+from celery import group
 from celery import canvas
 from celery import canvas
-from celery import current_app
 from celery import result
 from celery import result
 from celery.exceptions import ChordError
 from celery.exceptions import ChordError
 from celery.five import range
 from celery.five import range
 from celery.result import AsyncResult, GroupResult, EagerResult
 from celery.result import AsyncResult, GroupResult, EagerResult
-from celery.task import task, TaskSet
 from celery.tests.case import AppCase, Mock
 from celery.tests.case import AppCase, Mock
 
 
 passthru = lambda x: x
 passthru = lambda x: x
 
 
 
 
-@current_app.task
-def add(x, y):
-    return x + y
+class ChordCase(AppCase):
 
 
+    def setup(self):
 
 
-@current_app.task
-def callback(r):
-    return r
+        @self.app.task
+        def add(x, y):
+            return x + y
+        self.add = add
 
 
 
 
 class TSR(GroupResult):
 class TSR(GroupResult):
@@ -53,8 +52,8 @@ class TSRNoReport(TSR):
 
 
 
 
 @contextmanager
 @contextmanager
-def patch_unlock_retry():
-    unlock = current_app.tasks['celery.chord_unlock']
+def patch_unlock_retry(app):
+    unlock = app.tasks['celery.chord_unlock']
     retry = Mock()
     retry = Mock()
     prev, unlock.retry = unlock.retry, retry
     prev, unlock.retry = unlock.retry, retry
     try:
     try:
@@ -63,7 +62,7 @@ def patch_unlock_retry():
         unlock.retry = prev
         unlock.retry = prev
 
 
 
 
-class test_unlock_chord_task(AppCase):
+class test_unlock_chord_task(ChordCase):
 
 
     @patch('celery.result.GroupResult')
     @patch('celery.result.GroupResult')
     def test_unlock_ready(self, GroupResult):
     def test_unlock_ready(self, GroupResult):
@@ -133,7 +132,7 @@ class test_unlock_chord_task(AppCase):
     def _chord_context(self, ResultCls, setup=None, **kwargs):
     def _chord_context(self, ResultCls, setup=None, **kwargs):
         with patch('celery.result.GroupResult'):
         with patch('celery.result.GroupResult'):
 
 
-            @task()
+            @self.app.task()
             def callback(*args, **kwargs):
             def callback(*args, **kwargs):
                 pass
                 pass
 
 
@@ -143,7 +142,7 @@ class test_unlock_chord_task(AppCase):
             callback_s.id = 'callback_id'
             callback_s.id = 'callback_id'
             fail_current = self.app.backend.fail_from_current_stack = Mock()
             fail_current = self.app.backend.fail_from_current_stack = Mock()
             try:
             try:
-                with patch_unlock_retry() as (unlock, retry):
+                with patch_unlock_retry(self.app) as (unlock, retry):
                     subtask, canvas.maybe_subtask = (
                     subtask, canvas.maybe_subtask = (
                         canvas.maybe_subtask, passthru,
                         canvas.maybe_subtask, passthru,
                     )
                     )
@@ -173,19 +172,19 @@ class test_unlock_chord_task(AppCase):
             retry.assert_called_with(countdown=10, max_retries=30)
             retry.assert_called_with(countdown=10, max_retries=30)
 
 
     def test_is_in_registry(self):
     def test_is_in_registry(self):
-        self.assertIn('celery.chord_unlock', current_app.tasks)
+        self.assertIn('celery.chord_unlock', self.app.tasks)
 
 
 
 
-class test_chord(AppCase):
+class test_chord(ChordCase):
 
 
     def test_eager(self):
     def test_eager(self):
         from celery import chord
         from celery import chord
 
 
-        @task()
+        @self.app.task()
         def addX(x, y):
         def addX(x, y):
             return x + y
             return x + y
 
 
-        @task()
+        @self.app.task()
         def sumX(n):
         def sumX(n):
             return sum(n)
             return sum(n)
 
 
@@ -207,8 +206,8 @@ class test_chord(AppCase):
         m.AsyncResult = AsyncResult
         m.AsyncResult = AsyncResult
         prev, chord._type = chord._type, m
         prev, chord._type = chord._type, m
         try:
         try:
-            x = chord(add.s(i, i) for i in range(10))
-            body = add.s(2)
+            x = chord(self.add.s(i, i) for i in range(10))
+            body = self.add.s(2)
             result = x(body)
             result = x(body)
             self.assertTrue(result.id)
             self.assertTrue(result.id)
             # does not modify original subtask
             # does not modify original subtask
@@ -219,18 +218,18 @@ class test_chord(AppCase):
             chord._type = prev
             chord._type = prev
 
 
 
 
-class test_Chord_task(AppCase):
+class test_Chord_task(ChordCase):
 
 
     def test_run(self):
     def test_run(self):
-        prev, current_app.backend = current_app.backend, Mock()
-        current_app.backend.cleanup = Mock()
-        current_app.backend.cleanup.__name__ = 'cleanup'
+        prev, self.app.backend = self.app.backend, Mock()
+        self.app.backend.cleanup = Mock()
+        self.app.backend.cleanup.__name__ = 'cleanup'
         try:
         try:
-            Chord = current_app.tasks['celery.chord']
+            Chord = self.app.tasks['celery.chord']
 
 
             body = dict()
             body = dict()
-            Chord(TaskSet(add.subtask((i, i)) for i in range(5)), body)
-            Chord([add.subtask((j, j)) for j in range(5)], body)
-            self.assertEqual(current_app.backend.on_chord_apply.call_count, 2)
+            Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
+            Chord([self.add.subtask((j, j)) for j in range(5)], body)
+            self.assertEqual(self.app.backend.on_chord_apply.call_count, 2)
         finally:
         finally:
-            current_app.backend = prev
+            self.app.backend = prev

+ 5 - 5
celery/tests/tasks/test_http.py

@@ -13,7 +13,7 @@ from kombu.utils.encoding import from_utf8
 
 
 from celery.five import StringIO, items
 from celery.five import StringIO, items
 from celery.task import http
 from celery.task import http
-from celery.tests.case import Case, eager_tasks
+from celery.tests.case import AppCase, Case, eager_tasks
 
 
 
 
 @contextmanager
 @contextmanager
@@ -96,7 +96,7 @@ class test_MutableURL(Case):
         self.assertEqual(url.query, {'zzz': 'xxx'})
         self.assertEqual(url.query, {'zzz': 'xxx'})
 
 
 
 
-class test_HttpDispatch(Case):
+class test_HttpDispatch(AppCase):
 
 
     def test_dispatch_success(self):
     def test_dispatch_success(self):
         with mock_urlopen(success_response(100)):
         with mock_urlopen(success_response(100)):
@@ -139,16 +139,16 @@ class test_HttpDispatch(Case):
             self.assertEqual(d.dispatch(), 100)
             self.assertEqual(d.dispatch(), 100)
 
 
 
 
-class test_URL(Case):
+class test_URL(AppCase):
 
 
     def test_URL_get_async(self):
     def test_URL_get_async(self):
-        with eager_tasks():
+        with eager_tasks(self.app):
             with mock_urlopen(success_response(100)):
             with mock_urlopen(success_response(100)):
                 d = http.URL('http://example.com/mul').get_async(x=10, y=10)
                 d = http.URL('http://example.com/mul').get_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)
                 self.assertEqual(d.get(), 100)
 
 
     def test_URL_post_async(self):
     def test_URL_post_async(self):
-        with eager_tasks():
+        with eager_tasks(self.app):
             with mock_urlopen(success_response(100)):
             with mock_urlopen(success_response(100)):
                 d = http.URL('http://example.com/mul').post_async(x=10, y=10)
                 d = http.URL('http://example.com/mul').post_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)
                 self.assertEqual(d.get(), 100)

+ 15 - 14
celery/tests/tasks/test_result.py

@@ -23,11 +23,6 @@ from celery.tests.case import AppCase
 from celery.tests.case import skip_if_quick
 from celery.tests.case import skip_if_quick
 
 
 
 
-@task()
-def mytask():
-    pass
-
-
 def mock_task(name, state, result):
 def mock_task(name, state, result):
     return dict(id=uuid(), name=name, state=state, result=result)
     return dict(id=uuid(), name=name, state=state, result=result)
 
 
@@ -63,6 +58,11 @@ class test_AsyncResult(AppCase):
         for task in (self.task1, self.task2, self.task3, self.task4):
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(self.app, task)
             save_result(self.app, task)
 
 
+        @self.app.task()
+        def mytask():
+            pass
+        self.mytask = mytask
+
     def test_compat_properties(self):
     def test_compat_properties(self):
         x = self.app.AsyncResult('1')
         x = self.app.AsyncResult('1')
         self.assertEqual(x.task_id, x.id)
         self.assertEqual(x.task_id, x.id)
@@ -153,10 +153,10 @@ class test_AsyncResult(AppCase):
         self.assertFalse(self.app.AsyncResult('1') == object())
         self.assertFalse(self.app.AsyncResult('1') == object())
 
 
     def test_reduce(self):
     def test_reduce(self):
-        a1 = self.app.AsyncResult('uuid', task_name=mytask.name)
+        a1 = self.app.AsyncResult('uuid', task_name=self.mytask.name)
         restored = pickle.loads(pickle.dumps(a1))
         restored = pickle.loads(pickle.dumps(a1))
         self.assertEqual(restored.id, 'uuid')
         self.assertEqual(restored.id, 'uuid')
-        self.assertEqual(restored.task_name, mytask.name)
+        self.assertEqual(restored.task_name, self.mytask.name)
 
 
         a2 = self.app.AsyncResult('uuid')
         a2 = self.app.AsyncResult('uuid')
         self.assertEqual(pickle.loads(pickle.dumps(a2)).id, 'uuid')
         self.assertEqual(pickle.loads(pickle.dumps(a2)).id, 'uuid')
@@ -658,16 +658,17 @@ class test_pending_Group(AppCase):
             self.ts.join(timeout=1)
             self.ts.join(timeout=1)
 
 
 
 
-class RaisingTask(Task):
-
-    def run(self, x, y):
-        raise KeyError('xy')
+class test_EagerResult(AppCase):
 
 
+    def setup(self):
 
 
-class test_EagerResult(AppCase):
+        @self.app.task
+        def raising(x, y):
+            raise KeyError(x, y)
+        self.raising = raising
 
 
     def test_wait_raises(self):
     def test_wait_raises(self):
-        res = RaisingTask.apply(args=[3, 3])
+        res = self.raising.apply(args=[3, 3])
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             res.wait()
             res.wait()
         self.assertTrue(res.wait(propagate=False))
         self.assertTrue(res.wait(propagate=False))
@@ -683,7 +684,7 @@ class test_EagerResult(AppCase):
         res.forget()
         res.forget()
 
 
     def test_revoke(self):
     def test_revoke(self):
-        res = RaisingTask.apply(args=[3, 3])
+        res = self.raising.apply(args=[3, 3])
         self.assertFalse(res.revoke())
         self.assertFalse(res.revoke())
 
 
 
 

+ 33 - 29
celery/tests/tasks/test_tasks.py

@@ -17,8 +17,6 @@ from celery.task import (
     periodic_task,
     periodic_task,
     PeriodicTask
     PeriodicTask
 )
 )
-from celery import current_app
-from celery.app import app_or_default
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
 from celery.execute import send_task
 from celery.execute import send_task
 from celery.five import items, range, string_t
 from celery.five import items, range, string_t
@@ -27,11 +25,7 @@ from celery.schedules import crontab, crontab_parser, ParseException
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.timeutils import parse_iso8601, timedelta_seconds
 from celery.utils.timeutils import parse_iso8601, timedelta_seconds
 
 
-from celery.tests.case import Case, with_eager_tasks, WhateverIO
-
-
-def now():
-    return current_app.now()
+from celery.tests.case import AppCase, eager_tasks, WhateverIO
 
 
 
 
 def return_True(*args, **kwargs):
 def return_True(*args, **kwargs):
@@ -124,7 +118,7 @@ def retry_task_customexc(arg1, arg2, kwarg=1, **kwargs):
             raise current.retry(countdown=0, exc=exc)
             raise current.retry(countdown=0, exc=exc)
 
 
 
 
-class test_task_retries(Case):
+class test_task_retries(AppCase):
 
 
     def test_retry(self):
     def test_retry(self):
         retry_task.__class__.max_retries = 3
         retry_task.__class__.max_retries = 3
@@ -207,7 +201,7 @@ class test_task_retries(Case):
         self.assertEqual(retry_task.iterations, 2)
         self.assertEqual(retry_task.iterations, 2)
 
 
 
 
-class test_canvas_utils(Case):
+class test_canvas_utils(AppCase):
 
 
     def test_si(self):
     def test_si(self):
         self.assertTrue(retry_task.si())
         self.assertTrue(retry_task.si())
@@ -226,7 +220,10 @@ class test_canvas_utils(Case):
         retry_task.on_success(1, 1, (), {})
         retry_task.on_success(1, 1, (), {})
 
 
 
 
-class test_tasks(Case):
+class test_tasks(AppCase):
+
+    def now(self):
+        return self.app.now()
 
 
     def test_unpickle_task(self):
     def test_unpickle_task(self):
         import pickle
         import pickle
@@ -312,8 +309,8 @@ class test_tasks(Case):
         # With eta.
         # With eta.
         presult2 = T1.apply_async(
         presult2 = T1.apply_async(
             kwargs=dict(name='George Costanza'),
             kwargs=dict(name='George Costanza'),
-            eta=now() + timedelta(days=1),
-            expires=now() + timedelta(days=2),
+            eta=self.now() + timedelta(days=1),
+            expires=self.now() + timedelta(days=2),
         )
         )
         self.assertNextTaskDataEqual(
         self.assertNextTaskDataEqual(
             consumer, presult2, T1.name,
             consumer, presult2, T1.name,
@@ -411,7 +408,7 @@ class test_tasks(Case):
                 del(app.amqp.__dict__['TaskProducer'])
                 del(app.amqp.__dict__['TaskProducer'])
 
 
     def test_get_publisher(self):
     def test_get_publisher(self):
-        connection = app_or_default().connection()
+        connection = self.app.connection()
         p = increment_counter.get_publisher(connection, auto_declare=False,
         p = increment_counter.get_publisher(connection, auto_declare=False,
                                             exchange='foo')
                                             exchange='foo')
         self.assertEqual(p.exchange.name, 'foo')
         self.assertEqual(p.exchange.name, 'foo')
@@ -471,14 +468,14 @@ class test_tasks(Case):
             t1.pop_request()
             t1.pop_request()
 
 
 
 
-class test_TaskSet(Case):
+class test_TaskSet(AppCase):
 
 
-    @with_eager_tasks
     def test_function_taskset(self):
     def test_function_taskset(self):
-        subtasks = [return_True_task.s(i) for i in range(1, 6)]
-        ts = TaskSet(subtasks)
-        res = ts.apply_async()
-        self.assertListEqual(res.join(), [True, True, True, True, True])
+        with eager_tasks(self.app):
+            subtasks = [return_True_task.s(i) for i in range(1, 6)]
+            ts = TaskSet(subtasks)
+            res = ts.apply_async()
+            self.assertListEqual(res.join(), [True, True, True, True, True])
 
 
     def test_counter_taskset(self):
     def test_counter_taskset(self):
         increment_counter.count = 0
         increment_counter.count = 0
@@ -518,7 +515,7 @@ class test_TaskSet(Case):
         self.assertTrue(res.taskset_id.startswith(prefix))
         self.assertTrue(res.taskset_id.startswith(prefix))
 
 
 
 
-class test_apply_task(Case):
+class test_apply_task(AppCase):
 
 
     def test_apply_throw(self):
     def test_apply_throw(self):
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
@@ -569,7 +566,10 @@ def my_periodic():
     pass
     pass
 
 
 
 
-class test_periodic_tasks(Case):
+class test_periodic_tasks(AppCase):
+
+    def now(self):
+        return self.app.now()
 
 
     def test_must_have_run_every(self):
     def test_must_have_run_every(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
@@ -578,11 +578,11 @@ class test_periodic_tasks(Case):
     def test_remaining_estimate(self):
     def test_remaining_estimate(self):
         s = my_periodic.run_every
         s = my_periodic.run_every
         self.assertIsInstance(
         self.assertIsInstance(
-            s.remaining_estimate(s.maybe_make_aware(now())),
+            s.remaining_estimate(s.maybe_make_aware(self.now())),
             timedelta)
             timedelta)
 
 
     def test_is_due_not_due(self):
     def test_is_due_not_due(self):
-        due, remaining = my_periodic.run_every.is_due(now())
+        due, remaining = my_periodic.run_every.is_due(self.now())
         self.assertFalse(due)
         self.assertFalse(due)
         # This assertion may fail if executed in the
         # This assertion may fail if executed in the
         # first minute of an hour, thus 59 instead of 60
         # first minute of an hour, thus 59 instead of 60
@@ -591,7 +591,8 @@ class test_periodic_tasks(Case):
     def test_is_due(self):
     def test_is_due(self):
         p = my_periodic
         p = my_periodic
         due, remaining = p.run_every.is_due(
         due, remaining = p.run_every.is_due(
-            now() - p.run_every.run_every)
+            self.now() - p.run_every.run_every,
+        )
         self.assertTrue(due)
         self.assertTrue(due)
         self.assertEqual(remaining,
         self.assertEqual(remaining,
                          timedelta_seconds(p.run_every.run_every))
                          timedelta_seconds(p.run_every.run_every))
@@ -660,7 +661,7 @@ def patch_crontab_nowfun(cls, retval):
     return create_patcher
     return create_patcher
 
 
 
 
-class test_crontab_parser(Case):
+class test_crontab_parser(AppCase):
 
 
     def test_crontab_reduce(self):
     def test_crontab_reduce(self):
         self.assertTrue(loads(dumps(crontab('*'))))
         self.assertTrue(loads(dumps(crontab('*'))))
@@ -810,7 +811,7 @@ class test_crontab_parser(Case):
         self.assertFalse(crontab(minute='1') == object())
         self.assertFalse(crontab(minute='1') == object())
 
 
 
 
-class test_crontab_remaining_estimate(Case):
+class test_crontab_remaining_estimate(AppCase):
 
 
     def next_ocurrance(self, crontab, now):
     def next_ocurrance(self, crontab, now):
         crontab.nowfun = lambda: now
         crontab.nowfun = lambda: now
@@ -982,10 +983,13 @@ class test_crontab_remaining_estimate(Case):
         self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
         self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
 
 
 
 
-class test_crontab_is_due(Case):
+class test_crontab_is_due(AppCase):
+
+    def getnow(self):
+        return self.app.now()
 
 
-    def setUp(self):
-        self.now = now()
+    def setup(self):
+        self.now = self.getnow()
         self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
         self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
 
 
     def test_default_crontab_spec(self):
     def test_default_crontab_spec(self):

+ 4 - 3
celery/tests/worker/test_control.py

@@ -249,7 +249,7 @@ class test_ControlPanel(AppCase):
         self.panel.handle('report')
         self.panel.handle('report')
 
 
     def test_active(self):
     def test_active(self):
-        r = TaskRequest(mytask.name, 'do re mi', (), {})
+        r = TaskRequest(mytask.name, 'do re mi', (), {}, app=self.app)
         worker_state.active_requests.add(r)
         worker_state.active_requests.add(r)
         try:
         try:
             self.assertTrue(self.panel.handle('dump_active'))
             self.assertTrue(self.panel.handle('dump_active'))
@@ -331,7 +331,7 @@ class test_ControlPanel(AppCase):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         self.assertFalse(panel.handle('dump_schedule'))
         self.assertFalse(panel.handle('dump_schedule'))
-        r = TaskRequest(mytask.name, 'CAFEBABE', (), {})
+        r = TaskRequest(mytask.name, 'CAFEBABE', (), {}, app=self.app)
         consumer.timer.schedule.enter(
         consumer.timer.schedule.enter(
             consumer.timer.Entry(lambda x: x, (r, )),
             consumer.timer.Entry(lambda x: x, (r, )),
             datetime.now() + timedelta(seconds=10))
             datetime.now() + timedelta(seconds=10))
@@ -343,7 +343,8 @@ class test_ControlPanel(AppCase):
     def test_dump_reserved(self):
     def test_dump_reserved(self):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
         worker_state.reserved_requests.add(
         worker_state.reserved_requests.add(
-            TaskRequest(mytask.name, uuid(), args=(2, 2), kwargs={}),
+            TaskRequest(mytask.name, uuid(), args=(2, 2), kwargs={},
+                        app=self.app),
         )
         )
         try:
         try:
             panel = self.create_panel(consumer=consumer)
             panel = self.create_panel(consumer=consumer)

+ 25 - 25
celery/tests/worker/test_loops.py

@@ -15,7 +15,7 @@ from celery.tests.case import AppCase, body_from_sig
 
 
 class X(object):
 class X(object):
 
 
-    def __init__(self, heartbeat=None, on_task=None):
+    def __init__(self, app, heartbeat=None, on_task=None):
         (
         (
             self.obj,
             self.obj,
             self.connection,
             self.connection,
@@ -46,7 +46,7 @@ class X(object):
         self.hub.fire_timers.return_value = 1.7
         self.hub.fire_timers.return_value = 1.7
         self.Hub = self.hub
         self.Hub = self.hub
         # need this for create_task_handler
         # need this for create_task_handler
-        _consumer = Consumer(Mock(), timer=Mock())
+        _consumer = Consumer(Mock(), timer=Mock(), app=app)
         self.obj.create_task_handler = _consumer.create_task_handler
         self.obj.create_task_handler = _consumer.create_task_handler
         self.on_unknown_message = self.obj.on_unknown_message = Mock(
         self.on_unknown_message = self.obj.on_unknown_message = Mock(
             name='on_unknown_message',
             name='on_unknown_message',
@@ -105,7 +105,7 @@ class test_asynloop(AppCase):
         self.add = add
         self.add = add
 
 
     def test_setup_heartbeat(self):
     def test_setup_heartbeat(self):
-        x = X(heartbeat=10)
+        x = X(self.app, heartbeat=10)
         x.blueprint.state = CLOSE
         x.blueprint.state = CLOSE
         asynloop(*x.args)
         asynloop(*x.args)
         x.consumer.consume.assert_called_with()
         x.consumer.consume.assert_called_with()
@@ -115,7 +115,7 @@ class test_asynloop(AppCase):
         )
         )
 
 
     def task_context(self, sig, **kwargs):
     def task_context(self, sig, **kwargs):
-        x, on_task = get_task_callback(**kwargs)
+        x, on_task = get_task_callback(self.app, **kwargs)
         body = body_from_sig(self.app, sig)
         body = body_from_sig(self.app, sig)
         message = Mock()
         message = Mock()
         strategy = x.obj.strategies[sig.task] = Mock()
         strategy = x.obj.strategies[sig.task] = Mock()
@@ -153,7 +153,7 @@ class test_asynloop(AppCase):
         x.on_invalid_task.assert_called_with(body, msg, exc)
         x.on_invalid_task.assert_called_with(body, msg, exc)
 
 
     def test_should_terminate(self):
     def test_should_terminate(self):
-        x = X()
+        x = X(self.app)
         # XXX why aren't the errors propagated?!?
         # XXX why aren't the errors propagated?!?
         state.should_terminate = True
         state.should_terminate = True
         try:
         try:
@@ -163,7 +163,7 @@ class test_asynloop(AppCase):
             state.should_terminate = False
             state.should_terminate = False
 
 
     def test_should_terminate_hub_close_raises(self):
     def test_should_terminate_hub_close_raises(self):
-        x = X()
+        x = X(self.app)
         # XXX why aren't the errors propagated?!?
         # XXX why aren't the errors propagated?!?
         state.should_terminate = True
         state.should_terminate = True
         x.hub.close.side_effect = MemoryError()
         x.hub.close.side_effect = MemoryError()
@@ -174,7 +174,7 @@ class test_asynloop(AppCase):
             state.should_terminate = False
             state.should_terminate = False
 
 
     def test_should_stop(self):
     def test_should_stop(self):
-        x = X()
+        x = X(self.app)
         state.should_stop = True
         state.should_stop = True
         try:
         try:
             with self.assertRaises(SystemExit):
             with self.assertRaises(SystemExit):
@@ -183,13 +183,13 @@ class test_asynloop(AppCase):
             state.should_stop = False
             state.should_stop = False
 
 
     def test_updates_qos(self):
     def test_updates_qos(self):
-        x = X()
+        x = X(self.app)
         x.qos.prev = 3
         x.qos.prev = 3
         x.qos.value = 3
         x.qos.value = 3
         asynloop(*x.args, sleep=x.closer())
         asynloop(*x.args, sleep=x.closer())
         self.assertFalse(x.qos.update.called)
         self.assertFalse(x.qos.update.called)
 
 
-        x = X()
+        x = X(self.app)
         x.qos.prev = 1
         x.qos.prev = 1
         x.qos.value = 6
         x.qos.value = 6
         asynloop(*x.args, sleep=x.closer())
         asynloop(*x.args, sleep=x.closer())
@@ -198,7 +198,7 @@ class test_asynloop(AppCase):
         x.connection.transport.on_poll_start.assert_called_with()
         x.connection.transport.on_poll_start.assert_called_with()
 
 
     def test_poll_empty(self):
     def test_poll_empty(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.fire_timers.return_value = 33.37
         x.hub.fire_timers.return_value = 33.37
@@ -209,7 +209,7 @@ class test_asynloop(AppCase):
         x.connection.transport.on_poll_empty.assert_called_with()
         x.connection.transport.on_poll_empty.assert_called_with()
 
 
     def test_poll_readable(self):
     def test_poll_readable(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait, mod=4)
         x.close_then_error(x.connection.drain_nowait, mod=4)
         x.hub.poller.poll.return_value = [(6, READ)]
         x.hub.poller.poll.return_value = [(6, READ)]
@@ -219,7 +219,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
         self.assertTrue(x.hub.poller.poll.called)
 
 
     def test_poll_readable_raises_Empty(self):
     def test_poll_readable_raises_Empty(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, READ)]
         x.hub.poller.poll.return_value = [(6, READ)]
@@ -230,7 +230,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
         self.assertTrue(x.hub.poller.poll.called)
 
 
     def test_poll_writable(self):
     def test_poll_writable(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, WRITE)]
         x.hub.poller.poll.return_value = [(6, WRITE)]
@@ -240,7 +240,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
         self.assertTrue(x.hub.poller.poll.called)
 
 
     def test_poll_writable_none_registered(self):
     def test_poll_writable_none_registered(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(7, WRITE)]
         x.hub.poller.poll.return_value = [(7, WRITE)]
@@ -249,7 +249,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
         self.assertTrue(x.hub.poller.poll.called)
 
 
     def test_poll_unknown_event(self):
     def test_poll_unknown_event(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, 0)]
         x.hub.poller.poll.return_value = [(6, 0)]
@@ -258,7 +258,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
         self.assertTrue(x.hub.poller.poll.called)
 
 
     def test_poll_keep_draining_disabled(self):
     def test_poll_keep_draining_disabled(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.hub.writers = {6: Mock()}
         poll = x.hub.poller.poll
         poll = x.hub.poller.poll
 
 
@@ -275,7 +275,7 @@ class test_asynloop(AppCase):
         self.assertFalse(x.connection.drain_nowait.called)
         self.assertFalse(x.connection.drain_nowait.called)
 
 
     def test_poll_err_writable(self):
     def test_poll_err_writable(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, ERR)]
         x.hub.poller.poll.return_value = [(6, ERR)]
@@ -285,7 +285,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
         self.assertTrue(x.hub.poller.poll.called)
 
 
     def test_poll_write_generator(self):
     def test_poll_write_generator(self):
-        x = X()
+        x = X(self.app)
 
 
         def Gen():
         def Gen():
             yield 1
             yield 1
@@ -301,7 +301,7 @@ class test_asynloop(AppCase):
         self.assertFalse(x.hub.remove.called)
         self.assertFalse(x.hub.remove.called)
 
 
     def test_poll_write_generator_stopped(self):
     def test_poll_write_generator_stopped(self):
-        x = X()
+        x = X(self.app)
 
 
         def Gen():
         def Gen():
             raise StopIteration()
             raise StopIteration()
@@ -316,7 +316,7 @@ class test_asynloop(AppCase):
         x.hub.remove.assert_called_with(6)
         x.hub.remove.assert_called_with(6)
 
 
     def test_poll_write_generator_raises(self):
     def test_poll_write_generator_raises(self):
-        x = X()
+        x = X(self.app)
 
 
         def Gen():
         def Gen():
             raise ValueError('foo')
             raise ValueError('foo')
@@ -331,7 +331,7 @@ class test_asynloop(AppCase):
         x.hub.remove.assert_called_with(6)
         x.hub.remove.assert_called_with(6)
 
 
     def test_poll_err_readable(self):
     def test_poll_err_readable(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, ERR)]
         x.hub.poller.poll.return_value = [(6, ERR)]
@@ -341,7 +341,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
         self.assertTrue(x.hub.poller.poll.called)
 
 
     def test_poll_raises_ValueError(self):
     def test_poll_raises_ValueError(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.side_effect = ValueError()
         x.hub.poller.poll.side_effect = ValueError()
@@ -352,14 +352,14 @@ class test_asynloop(AppCase):
 class test_synloop(AppCase):
 class test_synloop(AppCase):
 
 
     def test_timeout_ignored(self):
     def test_timeout_ignored(self):
-        x = X()
+        x = X(self.app)
         x.timeout_then_error(x.connection.drain_events)
         x.timeout_then_error(x.connection.drain_events)
         with self.assertRaises(socket.error):
         with self.assertRaises(socket.error):
             synloop(*x.args)
             synloop(*x.args)
         self.assertEqual(x.connection.drain_events.call_count, 2)
         self.assertEqual(x.connection.drain_events.call_count, 2)
 
 
     def test_updates_qos_when_changed(self):
     def test_updates_qos_when_changed(self):
-        x = X()
+        x = X(self.app)
         x.qos.prev = 2
         x.qos.prev = 2
         x.qos.value = 2
         x.qos.value = 2
         x.timeout_then_error(x.connection.drain_events)
         x.timeout_then_error(x.connection.drain_events)
@@ -374,6 +374,6 @@ class test_synloop(AppCase):
         x.qos.update.assert_called_with()
         x.qos.update.assert_called_with()
 
 
     def test_ignores_socket_errors_when_closed(self):
     def test_ignores_socket_errors_when_closed(self):
-        x = X()
+        x = X(self.app)
         x.close_then_error(x.connection.drain_events)
         x.close_then_error(x.connection.drain_events)
         self.assertIsNone(synloop(*x.args))
         self.assertIsNone(synloop(*x.args))

+ 54 - 46
celery/tests/worker/test_request.py

@@ -203,7 +203,7 @@ class test_trace_task(AppCase):
                 if state == states.STARTED:
                 if state == states.STARTED:
                     self._started.append(tid)
                     self._started.append(tid)
 
 
-        prev, mytask.backend = mytask.backend, Backend()
+        prev, mytask.backend = mytask.backend, Backend(self.app)
         mytask.track_started = True
         mytask.track_started = True
 
 
         try:
         try:
@@ -350,7 +350,7 @@ class test_Request(AppCase):
             self.add.accept_magic_kwargs = False
             self.add.accept_magic_kwargs = False
 
 
     def test_task_wrapper_repr(self):
     def test_task_wrapper_repr(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         self.assertTrue(repr(tw))
         self.assertTrue(repr(tw))
 
 
     @patch('celery.worker.job.kwdict')
     @patch('celery.worker.job.kwdict')
@@ -358,7 +358,7 @@ class test_Request(AppCase):
 
 
         prev, module.NEEDS_KWDICT = module.NEEDS_KWDICT, True
         prev, module.NEEDS_KWDICT = module.NEEDS_KWDICT, True
         try:
         try:
-            TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
             self.assertTrue(kwdict.called)
             self.assertTrue(kwdict.called)
         finally:
         finally:
             module.NEEDS_KWDICT = prev
             module.NEEDS_KWDICT = prev
@@ -366,23 +366,25 @@ class test_Request(AppCase):
     def test_sets_store_errors(self):
     def test_sets_store_errors(self):
         mytask.ignore_result = True
         mytask.ignore_result = True
         try:
         try:
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             self.assertFalse(tw.store_errors)
             self.assertFalse(tw.store_errors)
             mytask.store_errors_even_if_ignored = True
             mytask.store_errors_even_if_ignored = True
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             self.assertTrue(tw.store_errors)
             self.assertTrue(tw.store_errors)
         finally:
         finally:
             mytask.ignore_result = False
             mytask.ignore_result = False
             mytask.store_errors_even_if_ignored = False
             mytask.store_errors_even_if_ignored = False
 
 
     def test_send_event(self):
     def test_send_event(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.eventer = MockEventDispatcher()
         tw.eventer = MockEventDispatcher()
         tw.send_event('task-frobulated')
         tw.send_event('task-frobulated')
         self.assertIn('task-frobulated', tw.eventer.sent)
         self.assertIn('task-frobulated', tw.eventer.sent)
 
 
     def test_on_retry(self):
     def test_on_retry(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.eventer = MockEventDispatcher()
         tw.eventer = MockEventDispatcher()
         try:
         try:
             raise RetryTaskError('foo', KeyError('moofoobar'))
             raise RetryTaskError('foo', KeyError('moofoobar'))
@@ -399,7 +401,7 @@ class test_Request(AppCase):
             tw.on_failure(einfo)
             tw.on_failure(einfo)
 
 
     def test_compat_properties(self):
     def test_compat_properties(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         self.assertEqual(tw.task_id, tw.id)
         self.assertEqual(tw.task_id, tw.id)
         self.assertEqual(tw.task_name, tw.name)
         self.assertEqual(tw.task_name, tw.name)
         tw.task_id = 'ID'
         tw.task_id = 'ID'
@@ -410,7 +412,7 @@ class test_Request(AppCase):
     def test_terminate__task_started(self):
     def test_terminate__task_started(self):
         pool = Mock()
         pool = Mock()
         signum = signal.SIGKILL
         signum = signal.SIGKILL
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=True,
                                   terminated=True,
                                   expired=False,
                                   expired=False,
@@ -422,7 +424,7 @@ class test_Request(AppCase):
 
 
     def test_terminate__task_reserved(self):
     def test_terminate__task_reserved(self):
         pool = Mock()
         pool = Mock()
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = None
         tw.time_start = None
         tw.terminate(pool, signal='KILL')
         tw.terminate(pool, signal='KILL')
         self.assertFalse(pool.terminate_job.called)
         self.assertFalse(pool.terminate_job.called)
@@ -431,7 +433,8 @@ class test_Request(AppCase):
 
 
     def test_revoked_expires_expired(self):
     def test_revoked_expires_expired(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
-                         expires=datetime.utcnow() - timedelta(days=1))
+                         expires=datetime.utcnow() - timedelta(days=1),
+                         app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=False,
                                   terminated=False,
                                   expired=True,
                                   expired=True,
@@ -443,7 +446,8 @@ class test_Request(AppCase):
 
 
     def test_revoked_expires_not_expired(self):
     def test_revoked_expires_not_expired(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
-                         expires=datetime.utcnow() + timedelta(days=1))
+                         expires=datetime.utcnow() + timedelta(days=1),
+                         app=self.app)
         tw.revoked()
         tw.revoked()
         self.assertNotIn(tw.id, revoked)
         self.assertNotIn(tw.id, revoked)
         self.assertNotEqual(
         self.assertNotEqual(
@@ -454,7 +458,8 @@ class test_Request(AppCase):
     def test_revoked_expires_ignore_result(self):
     def test_revoked_expires_ignore_result(self):
         mytask.ignore_result = True
         mytask.ignore_result = True
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
-                         expires=datetime.utcnow() - timedelta(days=1))
+                         expires=datetime.utcnow() - timedelta(days=1),
+                         app=self.app)
         try:
         try:
             tw.revoked()
             tw.revoked()
             self.assertIn(tw.id, revoked)
             self.assertIn(tw.id, revoked)
@@ -482,7 +487,8 @@ class test_Request(AppCase):
         app.mail_admins = mock_mail_admins
         app.mail_admins = mock_mail_admins
         mytask.send_error_emails = True
         mytask.send_error_emails = True
         try:
         try:
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
 
 
             einfo = get_ei()
             einfo = get_ei()
             tw.on_failure(einfo)
             tw.on_failure(einfo)
@@ -505,12 +511,12 @@ class test_Request(AppCase):
             mytask.send_error_emails = old_enable_mails
             mytask.send_error_emails = old_enable_mails
 
 
     def test_already_revoked(self):
     def test_already_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw._already_revoked = True
         tw._already_revoked = True
         self.assertTrue(tw.revoked())
         self.assertTrue(tw.revoked())
 
 
     def test_revoked(self):
     def test_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=False,
                                   terminated=False,
                                   expired=False,
                                   expired=False,
@@ -521,13 +527,13 @@ class test_Request(AppCase):
             self.assertTrue(tw.acknowledged)
             self.assertTrue(tw.acknowledged)
 
 
     def test_execute_does_not_execute_revoked(self):
     def test_execute_does_not_execute_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         revoked.add(tw.id)
         revoked.add(tw.id)
         tw.execute()
         tw.execute()
 
 
     def test_execute_acks_late(self):
     def test_execute_acks_late(self):
         mytask_raising.acks_late = True
         mytask_raising.acks_late = True
-        tw = TaskRequest(mytask_raising.name, uuid(), [1])
+        tw = TaskRequest(mytask_raising.name, uuid(), [1], app=self.app)
         try:
         try:
             tw.execute()
             tw.execute()
             self.assertTrue(tw.acknowledged)
             self.assertTrue(tw.acknowledged)
@@ -537,13 +543,13 @@ class test_Request(AppCase):
             mytask_raising.acks_late = False
             mytask_raising.acks_late = False
 
 
     def test_execute_using_pool_does_not_execute_revoked(self):
     def test_execute_using_pool_does_not_execute_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         revoked.add(tw.id)
         revoked.add(tw.id)
         with self.assertRaises(TaskRevokedError):
         with self.assertRaises(TaskRevokedError):
             tw.execute_using_pool(None)
             tw.execute_using_pool(None)
 
 
     def test_on_accepted_acks_early(self):
     def test_on_accepted_acks_early(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
         tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
         self.assertTrue(tw.acknowledged)
         self.assertTrue(tw.acknowledged)
         prev, module._does_debug = module._does_debug, False
         prev, module._does_debug = module._does_debug, False
@@ -553,7 +559,7 @@ class test_Request(AppCase):
             module._does_debug = prev
             module._does_debug = prev
 
 
     def test_on_accepted_acks_late(self):
     def test_on_accepted_acks_late(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         mytask.acks_late = True
         mytask.acks_late = True
         try:
         try:
             tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
             tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
@@ -564,7 +570,7 @@ class test_Request(AppCase):
     def test_on_accepted_terminates(self):
     def test_on_accepted_terminates(self):
         signum = signal.SIGKILL
         signum = signal.SIGKILL
         pool = Mock()
         pool = Mock()
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=True,
                                   terminated=True,
                                   expired=False,
                                   expired=False,
@@ -575,7 +581,7 @@ class test_Request(AppCase):
             pool.terminate_job.assert_called_with(314, signum)
             pool.terminate_job.assert_called_with(314, signum)
 
 
     def test_on_success_acks_early(self):
     def test_on_success_acks_early(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.time_start = 1
         tw.on_success(42)
         tw.on_success(42)
         prev, module._does_info = module._does_info, False
         prev, module._does_info = module._does_info, False
@@ -586,7 +592,7 @@ class test_Request(AppCase):
             module._does_info = prev
             module._does_info = prev
 
 
     def test_on_success_BaseException(self):
     def test_on_success_BaseException(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.time_start = 1
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             try:
             try:
@@ -597,7 +603,7 @@ class test_Request(AppCase):
                 assert False
                 assert False
 
 
     def test_on_success_eventer(self):
     def test_on_success_eventer(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.time_start = 1
         tw.eventer = Mock()
         tw.eventer = Mock()
         tw.send_event = Mock()
         tw.send_event = Mock()
@@ -605,7 +611,7 @@ class test_Request(AppCase):
         self.assertTrue(tw.send_event.called)
         self.assertTrue(tw.send_event.called)
 
 
     def test_on_success_when_failure(self):
     def test_on_success_when_failure(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.time_start = 1
         tw.on_failure = Mock()
         tw.on_failure = Mock()
         try:
         try:
@@ -615,7 +621,7 @@ class test_Request(AppCase):
             self.assertTrue(tw.on_failure.called)
             self.assertTrue(tw.on_failure.called)
 
 
     def test_on_success_acks_late(self):
     def test_on_success_acks_late(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.time_start = 1
         mytask.acks_late = True
         mytask.acks_late = True
         try:
         try:
@@ -632,7 +638,7 @@ class test_Request(AppCase):
             except WorkerLostError:
             except WorkerLostError:
                 return ExceptionInfo()
                 return ExceptionInfo()
 
 
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         exc_info = get_ei()
         exc_info = get_ei()
         tw.on_failure(exc_info)
         tw.on_failure(exc_info)
         self.assertEqual(mytask.backend.get_status(tw.id),
         self.assertEqual(mytask.backend.get_status(tw.id),
@@ -641,7 +647,8 @@ class test_Request(AppCase):
         mytask.ignore_result = True
         mytask.ignore_result = True
         try:
         try:
             exc_info = get_ei()
             exc_info = get_ei()
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             tw.on_failure(exc_info)
             tw.on_failure(exc_info)
             self.assertEqual(mytask.backend.get_status(tw.id),
             self.assertEqual(mytask.backend.get_status(tw.id),
                              states.PENDING)
                              states.PENDING)
@@ -649,7 +656,7 @@ class test_Request(AppCase):
             mytask.ignore_result = False
             mytask.ignore_result = False
 
 
     def test_on_failure_acks_late(self):
     def test_on_failure_acks_late(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.time_start = 1
         mytask.acks_late = True
         mytask.acks_late = True
         try:
         try:
@@ -665,13 +672,13 @@ class test_Request(AppCase):
     def test_from_message_invalid_kwargs(self):
     def test_from_message_invalid_kwargs(self):
         body = dict(task=mytask.name, id=1, args=(), kwargs='foo')
         body = dict(task=mytask.name, id=1, args=(), kwargs='foo')
         with self.assertRaises(InvalidTaskError):
         with self.assertRaises(InvalidTaskError):
-            TaskRequest.from_message(None, body)
+            TaskRequest.from_message(None, body, app=self.app)
 
 
     @patch('celery.worker.job.error')
     @patch('celery.worker.job.error')
     @patch('celery.worker.job.warn')
     @patch('celery.worker.job.warn')
     def test_on_timeout(self, warn, error):
     def test_on_timeout(self, warn, error):
 
 
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.on_timeout(soft=True, timeout=1337)
         tw.on_timeout(soft=True, timeout=1337)
         self.assertIn('Soft time limit', warn.call_args[0][0])
         self.assertIn('Soft time limit', warn.call_args[0][0])
         tw.on_timeout(soft=False, timeout=1337)
         tw.on_timeout(soft=False, timeout=1337)
@@ -681,7 +688,8 @@ class test_Request(AppCase):
 
 
         mytask.ignore_result = True
         mytask.ignore_result = True
         try:
         try:
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             tw.on_timeout(soft=True, timeout=1336)
             tw.on_timeout(soft=True, timeout=1336)
             self.assertEqual(mytask.backend.get_status(tw.id),
             self.assertEqual(mytask.backend.get_status(tw.id),
                              states.PENDING)
                              states.PENDING)
@@ -771,7 +779,7 @@ class test_Request(AppCase):
             mytask.pop_request()
             mytask.pop_request()
 
 
     def test_task_wrapper_mail_attrs(self):
     def test_task_wrapper_mail_attrs(self):
-        tw = TaskRequest(mytask.name, uuid(), [], {})
+        tw = TaskRequest(mytask.name, uuid(), [], {}, app=self.app)
         x = tw.success_msg % {
         x = tw.success_msg % {
             'name': tw.name,
             'name': tw.name,
             'id': tw.id,
             'id': tw.id,
@@ -794,7 +802,7 @@ class test_Request(AppCase):
         m = Message(None, body=anyjson.dumps(body), backend='foo',
         m = Message(None, body=anyjson.dumps(body), backend='foo',
                     content_type='application/json',
                     content_type='application/json',
                     content_encoding='utf-8')
                     content_encoding='utf-8')
-        tw = TaskRequest.from_message(m, m.decode())
+        tw = TaskRequest.from_message(m, m.decode(), app=self.app)
         self.assertIsInstance(tw, Request)
         self.assertIsInstance(tw, Request)
         self.assertEqual(tw.name, body['task'])
         self.assertEqual(tw.name, body['task'])
         self.assertEqual(tw.id, body['id'])
         self.assertEqual(tw.id, body['id'])
@@ -809,7 +817,7 @@ class test_Request(AppCase):
         m = Message(None, body=anyjson.dumps(body), backend='foo',
         m = Message(None, body=anyjson.dumps(body), backend='foo',
                     content_type='application/json',
                     content_type='application/json',
                     content_encoding='utf-8')
                     content_encoding='utf-8')
-        tw = TaskRequest.from_message(m, m.decode())
+        tw = TaskRequest.from_message(m, m.decode(), app=self.app)
         self.assertIsInstance(tw, Request)
         self.assertIsInstance(tw, Request)
         self.assertEquals(tw.args, [])
         self.assertEquals(tw.args, [])
         self.assertEquals(tw.kwargs, {})
         self.assertEquals(tw.kwargs, {})
@@ -820,7 +828,7 @@ class test_Request(AppCase):
                     content_type='application/json',
                     content_type='application/json',
                     content_encoding='utf-8')
                     content_encoding='utf-8')
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
-            TaskRequest.from_message(m, m.decode())
+            TaskRequest.from_message(m, m.decode(), app=self.app)
 
 
     def test_from_message_nonexistant_task(self):
     def test_from_message_nonexistant_task(self):
         body = {'task': 'cu.mytask.doesnotexist', 'id': uuid(),
         body = {'task': 'cu.mytask.doesnotexist', 'id': uuid(),
@@ -829,11 +837,11 @@ class test_Request(AppCase):
                     content_type='application/json',
                     content_type='application/json',
                     content_encoding='utf-8')
                     content_encoding='utf-8')
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
-            TaskRequest.from_message(m, m.decode())
+            TaskRequest.from_message(m, m.decode(), app=self.app)
 
 
     def test_execute(self):
     def test_execute(self):
         tid = uuid()
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
         self.assertEqual(tw.execute(), 256)
         self.assertEqual(tw.execute(), 256)
         meta = mytask.backend.get_task_meta(tid)
         meta = mytask.backend.get_task_meta(tid)
         self.assertEqual(meta['result'], 256)
         self.assertEqual(meta['result'], 256)
@@ -841,7 +849,7 @@ class test_Request(AppCase):
 
 
     def test_execute_success_no_kwargs(self):
     def test_execute_success_no_kwargs(self):
         tid = uuid()
         tid = uuid()
-        tw = TaskRequest(mytask_no_kwargs.name, tid, [4], {})
+        tw = TaskRequest(mytask_no_kwargs.name, tid, [4], {}, app=self.app)
         self.assertEqual(tw.execute(), 256)
         self.assertEqual(tw.execute(), 256)
         meta = mytask_no_kwargs.backend.get_task_meta(tid)
         meta = mytask_no_kwargs.backend.get_task_meta(tid)
         self.assertEqual(meta['result'], 256)
         self.assertEqual(meta['result'], 256)
@@ -849,7 +857,7 @@ class test_Request(AppCase):
 
 
     def test_execute_success_some_kwargs(self):
     def test_execute_success_some_kwargs(self):
         tid = uuid()
         tid = uuid()
-        tw = TaskRequest(mytask_some_kwargs.name, tid, [4], {})
+        tw = TaskRequest(mytask_some_kwargs.name, tid, [4], {}, app=self.app)
         self.assertEqual(tw.execute(), 256)
         self.assertEqual(tw.execute(), 256)
         meta = mytask_some_kwargs.backend.get_task_meta(tid)
         meta = mytask_some_kwargs.backend.get_task_meta(tid)
         self.assertEqual(some_kwargs_scratchpad.get('task_id'), tid)
         self.assertEqual(some_kwargs_scratchpad.get('task_id'), tid)
@@ -859,7 +867,7 @@ class test_Request(AppCase):
     def test_execute_ack(self):
     def test_execute_ack(self):
         tid = uuid()
         tid = uuid()
         tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'},
         tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'},
-                         on_ack=on_ack)
+                         on_ack=on_ack, app=self.app)
         self.assertEqual(tw.execute(), 256)
         self.assertEqual(tw.execute(), 256)
         meta = mytask.backend.get_task_meta(tid)
         meta = mytask.backend.get_task_meta(tid)
         self.assertTrue(scratch['ACK'])
         self.assertTrue(scratch['ACK'])
@@ -868,7 +876,7 @@ class test_Request(AppCase):
 
 
     def test_execute_fail(self):
     def test_execute_fail(self):
         tid = uuid()
         tid = uuid()
-        tw = TaskRequest(mytask_raising.name, tid, [4])
+        tw = TaskRequest(mytask_raising.name, tid, [4], app=self.app)
         self.assertIsInstance(tw.execute(), ExceptionInfo)
         self.assertIsInstance(tw.execute(), ExceptionInfo)
         meta = mytask_raising.backend.get_task_meta(tid)
         meta = mytask_raising.backend.get_task_meta(tid)
         self.assertEqual(meta['status'], states.FAILURE)
         self.assertEqual(meta['status'], states.FAILURE)
@@ -876,7 +884,7 @@ class test_Request(AppCase):
 
 
     def test_execute_using_pool(self):
     def test_execute_using_pool(self):
         tid = uuid()
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
 
 
         class MockPool(BasePool):
         class MockPool(BasePool):
             target = None
             target = None
@@ -906,7 +914,7 @@ class test_Request(AppCase):
 
 
     def test_default_kwargs(self):
     def test_default_kwargs(self):
         tid = uuid()
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
         self.assertDictEqual(
         self.assertDictEqual(
             tw.extend_with_default_kwargs(), {
             tw.extend_with_default_kwargs(), {
                 'f': 'x',
                 'f': 'x',
@@ -926,7 +934,7 @@ class test_Request(AppCase):
     def _test_on_failure(self, exception, logger):
     def _test_on_failure(self, exception, logger):
         app = self.app
         app = self.app
         tid = uuid()
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
         try:
         try:
             raise exception
             raise exception
         except Exception:
         except Exception:

+ 41 - 40
celery/tests/worker/test_worker.py

@@ -23,7 +23,6 @@ from celery.five import Empty, range, Queue as FastQueue
 from celery.task import task as task_dec
 from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.worker import WorkController
 from celery.worker import components
 from celery.worker import components
 from celery.worker import consumer
 from celery.worker import consumer
 from celery.worker.consumer import Consumer as __Consumer
 from celery.worker.consumer import Consumer as __Consumer
@@ -243,7 +242,7 @@ class test_Consumer(AppCase):
         self.timer.stop()
         self.timer.stop()
 
 
     def test_info(self):
     def test_info(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer.qos, 10)
         l.qos = QoS(l.task_consumer.qos, 10)
         l.connection = Mock()
         l.connection = Mock()
@@ -257,12 +256,12 @@ class test_Consumer(AppCase):
         self.assertTrue(info['broker'])
         self.assertTrue(info['broker'])
 
 
     def test_start_when_closed(self):
     def test_start_when_closed(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = CLOSE
         l.blueprint.state = CLOSE
         l.start()
         l.start()
 
 
     def test_connection(self):
     def test_connection(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
 
 
         l.blueprint.start(l)
         l.blueprint.start(l)
         self.assertIsInstance(l.connection, Connection)
         self.assertIsInstance(l.connection, Connection)
@@ -287,7 +286,7 @@ class test_Consumer(AppCase):
         self.assertIsNone(l.task_consumer)
         self.assertIsNone(l.task_consumer)
 
 
     def test_close_connection(self):
     def test_close_connection(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.blueprint.state = RUN
         step = find_step(l, consumer.Connection)
         step = find_step(l, consumer.Connection)
         conn = l.connection = Mock()
         conn = l.connection = Mock()
@@ -295,7 +294,7 @@ class test_Consumer(AppCase):
         self.assertTrue(conn.close.called)
         self.assertTrue(conn.close.called)
         self.assertIsNone(l.connection)
         self.assertIsNone(l.connection)
 
 
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         eventer = l.event_dispatcher = Mock()
         eventer = l.event_dispatcher = Mock()
         eventer.enabled = True
         eventer.enabled = True
         heart = l.heart = MockHeart()
         heart = l.heart = MockHeart()
@@ -309,7 +308,7 @@ class test_Consumer(AppCase):
 
 
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.warn')
     def test_receive_message_unknown(self, warn):
     def test_receive_message_unknown(self, warn):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         backend = Mock()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
         m = create_message(backend, unknown={'baz': '!!!'})
@@ -323,7 +322,7 @@ class test_Consumer(AppCase):
     @patch('celery.worker.strategy.to_timestamp')
     @patch('celery.worker.strategy.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
         to_timestamp.side_effect = OverflowError()
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            args=('2, 2'),
                            args=('2, 2'),
@@ -340,7 +339,7 @@ class test_Consumer(AppCase):
 
 
     @patch('celery.worker.consumer.error')
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
     def test_receive_message_InvalidTaskError(self, error):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.steps.pop()
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
@@ -354,7 +353,7 @@ class test_Consumer(AppCase):
 
 
     @patch('celery.worker.consumer.crit')
     @patch('celery.worker.consumer.crit')
     def test_on_decode_error(self, crit):
     def test_on_decode_error(self, crit):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
 
 
         class MockMessage(Mock):
         class MockMessage(Mock):
             content_type = 'application/x-msgpack'
             content_type = 'application/x-msgpack'
@@ -380,7 +379,7 @@ class test_Consumer(AppCase):
         return l.task_consumer.register_callback.call_args[0][0]
         return l.task_consumer.register_callback.call_args[0][0]
 
 
     def test_receieve_message(self):
     def test_receieve_message(self):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         m = create_message(Mock(), task=foo_task.name,
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
                            args=[2, 4, 8], kwargs={})
@@ -406,7 +405,7 @@ class test_Consumer(AppCase):
                 raise SyntaxError('bar')
                 raise SyntaxError('bar')
 
 
         l = MockConsumer(self.buffer.put, timer=self.timer,
         l = MockConsumer(self.buffer.put, timer=self.timer,
-                         send_events=False, pool=BasePool())
+                         send_events=False, pool=BasePool(), app=self.app)
         l.channel_errors = (KeyError, )
         l.channel_errors = (KeyError, )
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             l.start()
             l.start()
@@ -424,7 +423,7 @@ class test_Consumer(AppCase):
                 raise SyntaxError('bar')
                 raise SyntaxError('bar')
 
 
         l = MockConsumer(self.buffer.put, timer=self.timer,
         l = MockConsumer(self.buffer.put, timer=self.timer,
-                         send_events=False, pool=BasePool())
+                         send_events=False, pool=BasePool(), app=self.app)
 
 
         l.connection_errors = (KeyError, )
         l.connection_errors = (KeyError, )
         self.assertRaises(SyntaxError, l.start)
         self.assertRaises(SyntaxError, l.start)
@@ -439,7 +438,7 @@ class test_Consumer(AppCase):
                 self.obj.connection = None
                 self.obj.connection = None
                 raise socket.timeout(10)
                 raise socket.timeout(10)
 
 
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.connection = Connection()
         l.connection = Connection()
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.connection.obj = l
         l.connection.obj = l
@@ -455,7 +454,7 @@ class test_Consumer(AppCase):
                 self.obj.connection = None
                 self.obj.connection = None
                 raise socket.error('foo')
                 raise socket.error('foo')
 
 
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.blueprint.state = RUN
         c = l.connection = Connection()
         c = l.connection = Connection()
         l.connection.obj = l
         l.connection.obj = l
@@ -476,7 +475,7 @@ class test_Consumer(AppCase):
             def drain_events(self, **kwargs):
             def drain_events(self, **kwargs):
                 self.obj.connection = None
                 self.obj.connection = None
 
 
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.connection = Connection()
         l.connection = Connection()
         l.connection.obj = l
         l.connection.obj = l
         l.task_consumer = Mock()
         l.task_consumer = Mock()
@@ -494,7 +493,7 @@ class test_Consumer(AppCase):
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
 
     def test_ignore_errors(self):
     def test_ignore_errors(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.connection_errors = (AttributeError, KeyError, )
         l.connection_errors = (AttributeError, KeyError, )
         l.channel_errors = (SyntaxError, )
         l.channel_errors = (SyntaxError, )
         ignore_errors(l, Mock(side_effect=AttributeError('foo')))
         ignore_errors(l, Mock(side_effect=AttributeError('foo')))
@@ -505,7 +504,7 @@ class test_Consumer(AppCase):
 
 
     def test_apply_eta_task(self):
     def test_apply_eta_task(self):
         from celery.worker import state
         from celery.worker import state
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.qos = QoS(None, 10)
         l.qos = QoS(None, 10)
 
 
         task = object()
         task = object()
@@ -516,7 +515,7 @@ class test_Consumer(AppCase):
         self.assertIs(self.buffer.get_nowait(), task)
         self.assertIs(self.buffer.get_nowait(), task)
 
 
     def test_receieve_message_eta_isoformat(self):
     def test_receieve_message_eta_isoformat(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         m = create_message(
         m = create_message(
             Mock(), task=foo_task.name,
             Mock(), task=foo_task.name,
@@ -545,7 +544,7 @@ class test_Consumer(AppCase):
         l.timer.stop()
         l.timer.stop()
 
 
     def test_pidbox_callback(self):
     def test_pidbox_callback(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         con = find_step(l, consumer.Control).box
         con = find_step(l, consumer.Control).box
         con.node = Mock()
         con.node = Mock()
         con.reset = Mock()
         con.reset = Mock()
@@ -565,7 +564,7 @@ class test_Consumer(AppCase):
         self.assertTrue(con.reset.called)
         self.assertTrue(con.reset.called)
 
 
     def test_revoke(self):
     def test_revoke(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         backend = Mock()
         backend = Mock()
         id = uuid()
         id = uuid()
@@ -579,7 +578,7 @@ class test_Consumer(AppCase):
         self.assertTrue(self.buffer.empty())
         self.assertTrue(self.buffer.empty())
 
 
     def test_receieve_message_not_registered(self):
     def test_receieve_message_not_registered(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
@@ -594,7 +593,7 @@ class test_Consumer(AppCase):
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.logger')
     @patch('celery.worker.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
     def test_receieve_message_ack_raises(self, logger, warn):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         backend = Mock()
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
 
@@ -612,7 +611,7 @@ class test_Consumer(AppCase):
         self.assertTrue(logger.critical.call_count)
         self.assertTrue(logger.critical.call_count)
 
 
     def test_receive_message_eta(self):
     def test_receive_message_eta(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.event_dispatcher._outbound_buffer = deque()
         l.event_dispatcher._outbound_buffer = deque()
@@ -646,7 +645,7 @@ class test_Consumer(AppCase):
             self.buffer.get_nowait()
             self.buffer.get_nowait()
 
 
     def test_reset_pidbox_node(self):
     def test_reset_pidbox_node(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         con = find_step(l, consumer.Control).box
         con = find_step(l, consumer.Control).box
         con.node = Mock()
         con.node = Mock()
         chan = con.node.channel = Mock()
         chan = con.node.channel = Mock()
@@ -660,7 +659,8 @@ class test_Consumer(AppCase):
         from celery.worker.pidbox import gPidbox
         from celery.worker.pidbox import gPidbox
         pool = Mock()
         pool = Mock()
         pool.is_green = True
         pool.is_green = True
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
+                            app=self.app)
         con = find_step(l, consumer.Control)
         con = find_step(l, consumer.Control)
         self.assertIsInstance(con.box, gPidbox)
         self.assertIsInstance(con.box, gPidbox)
         con.start(l)
         con.start(l)
@@ -671,7 +671,8 @@ class test_Consumer(AppCase):
     def test__green_pidbox_node(self):
     def test__green_pidbox_node(self):
         pool = Mock()
         pool = Mock()
         pool.is_green = True
         pool.is_green = True
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
+                            app=self.app)
         l.node = Mock()
         l.node = Mock()
         controller = find_step(l, consumer.Control)
         controller = find_step(l, consumer.Control)
 
 
@@ -733,7 +734,7 @@ class test_Consumer(AppCase):
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.utils.sleep')
     @patch('kombu.utils.sleep')
     def test_connect_errback(self, sleep, connect):
     def test_connect_errback(self, sleep, connect):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         from kombu.transport.memory import Transport
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
         Transport.connection_errors = (StdChannelError, )
 
 
@@ -746,7 +747,7 @@ class test_Consumer(AppCase):
         connect.assert_called_with()
         connect.assert_called_with()
 
 
     def test_stop_pidbox_node(self):
     def test_stop_pidbox_node(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         cont = find_step(l, consumer.Control)
         cont = find_step(l, consumer.Control)
         cont._node_stopped = Event()
         cont._node_stopped = Event()
         cont._node_shutdown = Event()
         cont._node_shutdown = Event()
@@ -771,7 +772,7 @@ class test_Consumer(AppCase):
 
 
         init_callback = Mock()
         init_callback = Mock()
         l = _Consumer(self.buffer.put, timer=self.timer,
         l = _Consumer(self.buffer.put, timer=self.timer,
-                      init_callback=init_callback)
+                      init_callback=init_callback, app=self.app)
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.qos = _QoS()
         l.qos = _QoS()
@@ -792,7 +793,7 @@ class test_Consumer(AppCase):
         self.assertEqual(l.qos.prev, l.qos.value)
         self.assertEqual(l.qos.prev, l.qos.value)
 
 
         init_callback.reset_mock()
         init_callback.reset_mock()
-        l = _Consumer(self.buffer.put, timer=self.timer,
+        l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
                       send_events=False, init_callback=init_callback)
                       send_events=False, init_callback=init_callback)
         l.qos = _QoS()
         l.qos = _QoS()
         l.task_consumer = Mock()
         l.task_consumer = Mock()
@@ -804,7 +805,7 @@ class test_Consumer(AppCase):
         self.assertTrue(l.loop.call_count)
         self.assertTrue(l.loop.call_count)
 
 
     def test_reset_connection_with_no_node(self):
     def test_reset_connection_with_no_node(self):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.steps.pop()
         self.assertEqual(None, l.pool)
         self.assertEqual(None, l.pool)
         l.blueprint.start(l)
         l.blueprint.start(l)
@@ -921,7 +922,7 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.steps)
         self.assertTrue(worker.steps)
 
 
     def test_with_embedded_beat(self):
     def test_with_embedded_beat(self):
-        worker = WorkController(concurrency=1, loglevel=0, beat=True)
+        worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
         self.assertTrue(worker.beat)
         self.assertTrue(worker.beat)
         self.assertIn(worker.beat, [w.obj for w in worker.steps])
         self.assertIn(worker.beat, [w.obj for w in worker.steps])
 
 
@@ -933,7 +934,7 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.autoscaler)
         self.assertTrue(worker.autoscaler)
 
 
     def test_dont_stop_or_terminate(self):
     def test_dont_stop_or_terminate(self):
-        worker = WorkController(concurrency=1, loglevel=0)
+        worker = self.app.WorkController(concurrency=1, loglevel=0)
         worker.stop()
         worker.stop()
         self.assertNotEqual(worker.blueprint.state, CLOSE)
         self.assertNotEqual(worker.blueprint.state, CLOSE)
         worker.terminate()
         worker.terminate()
@@ -950,7 +951,7 @@ class test_WorkController(AppCase):
             worker.pool.signal_safe = sigsafe
             worker.pool.signal_safe = sigsafe
 
 
     def test_on_timer_error(self):
     def test_on_timer_error(self):
-        worker = WorkController(concurrency=1, loglevel=0)
+        worker = self.app.WorkController(concurrency=1, loglevel=0)
 
 
         try:
         try:
             raise KeyError('foo')
             raise KeyError('foo')
@@ -960,7 +961,7 @@ class test_WorkController(AppCase):
             self.assertIn('KeyError', msg % args)
             self.assertIn('KeyError', msg % args)
 
 
     def test_on_timer_tick(self):
     def test_on_timer_tick(self):
-        worker = WorkController(concurrency=1, loglevel=10)
+        worker = self.app.WorkController(concurrency=1, loglevel=10)
 
 
         components.Timer(worker).on_timer_tick(30.0)
         components.Timer(worker).on_timer_tick(30.0)
         xargs = self.comp_logger.debug.call_args[0]
         xargs = self.comp_logger.debug.call_args[0]
@@ -974,7 +975,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)
         worker._process_task(task)
         self.assertEqual(worker.pool.apply_async.call_count, 1)
         self.assertEqual(worker.pool.apply_async.call_count, 1)
         worker.pool.stop()
         worker.pool.stop()
@@ -986,7 +987,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
         worker.steps = []
         worker.blueprint.state = RUN
         worker.blueprint.state = RUN
         with self.assertRaises(KeyboardInterrupt):
         with self.assertRaises(KeyboardInterrupt):
@@ -1000,7 +1001,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
         worker.steps = []
         worker.blueprint.state = RUN
         worker.blueprint.state = RUN
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
@@ -1014,7 +1015,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)
         worker._process_task(task)
         worker.pool.stop()
         worker.pool.stop()