Browse Source

Tests passing

Ask Solem 11 years ago
parent
commit
dbe733d234
37 changed files with 492 additions and 474 deletions
  1. 1 0
      celery/__init__.py
  2. 1 0
      celery/_state.py
  3. 1 3
      celery/app/control.py
  4. 1 2
      celery/apps/beat.py
  5. 6 8
      celery/beat.py
  6. 1 2
      celery/bin/amqp.py
  7. 3 3
      celery/events/cursesmon.py
  8. 2 3
      celery/loaders/base.py
  9. 1 1
      celery/schedules.py
  10. 1 1
      celery/tests/app/test_app.py
  11. 5 3
      celery/tests/app/test_beat.py
  12. 10 9
      celery/tests/app/test_loaders.py
  13. 8 10
      celery/tests/backends/test_amqp.py
  14. 7 6
      celery/tests/backends/test_backends.py
  15. 61 45
      celery/tests/backends/test_base.py
  16. 10 11
      celery/tests/backends/test_cache.py
  17. 18 19
      celery/tests/backends/test_database.py
  18. 7 7
      celery/tests/backends/test_mongodb.py
  19. 24 25
      celery/tests/backends/test_redis.py
  20. 13 13
      celery/tests/bin/test_beat.py
  21. 2 2
      celery/tests/bin/test_celeryd_detach.py
  22. 3 5
      celery/tests/bin/test_events.py
  23. 34 32
      celery/tests/bin/test_worker.py
  24. 2 18
      celery/tests/case.py
  25. 13 12
      celery/tests/compat_modules/test_sets.py
  26. 4 4
      celery/tests/events/test_cursesmon.py
  27. 6 9
      celery/tests/events/test_snapshot.py
  28. 3 3
      celery/tests/security/case.py
  29. 40 28
      celery/tests/security/test_security.py
  30. 27 28
      celery/tests/tasks/test_chord.py
  31. 5 5
      celery/tests/tasks/test_http.py
  32. 15 14
      celery/tests/tasks/test_result.py
  33. 33 29
      celery/tests/tasks/test_tasks.py
  34. 4 3
      celery/tests/worker/test_control.py
  35. 25 25
      celery/tests/worker/test_loops.py
  36. 54 46
      celery/tests/worker/test_request.py
  37. 41 40
      celery/tests/worker/test_worker.py

+ 1 - 0
celery/__init__.py

@@ -44,6 +44,7 @@ if STATICA_HACK:  # pragma: no cover
     # This is never executed, but tricks static analyzers (PyDev, PyCharm,
     # pylint, etc.) into knowing the types of these symbols, and what
     # they contain.
+    from celery.app import shared_task                   # noqa
     from celery.app.base import Celery                   # noqa
     from celery.app.utils import bugreport               # noqa
     from celery.app.task import Task                     # noqa

+ 1 - 0
celery/_state.py

@@ -55,6 +55,7 @@ def _get_current_app():
 C_STRICT_APP = os.environ.get('C_STRICT_APP')
 if os.environ.get('C_STRICT_APP'):  # pragma: no cover
     def get_current_app():
+        raise Exception('USES CURRENT APP')
         import traceback
         print('-- USES CURRENT_APP', file=sys.stderr)  # noqa+
         traceback.print_stack(file=sys.stderr)

+ 1 - 3
celery/app/control.py

@@ -16,8 +16,6 @@ from kombu.utils import cached_property
 
 from celery.exceptions import DuplicateNodenameWarning
 
-from . import app_or_default
-
 W_DUPNODE = """\
 Received multiple replies from node name {0!r}.
 Please make sure you give each node a unique nodename using the `-n` option.\
@@ -121,7 +119,7 @@ class Control(object):
     Mailbox = Mailbox
 
     def __init__(self, app=None):
-        self.app = app_or_default(app)
+        self.app = app
         self.mailbox = self.Mailbox('celery', type='fanout',
                                     accept=self.app.conf.CELERY_ACCEPT_CONTENT)
 

+ 1 - 2
celery/apps/beat.py

@@ -16,7 +16,6 @@ import socket
 import sys
 
 from celery import VERSION_BANNER, platforms, beat
-from celery.app import app_or_default
 from celery.utils.imports import qualname
 from celery.utils.log import LOG_LEVELS, get_logger
 from celery.utils.timeutils import humanize_seconds
@@ -44,7 +43,7 @@ class Beat(object):
                  scheduler_cls=None, redirect_stdouts=None,
                  redirect_stdouts_level=None, **kwargs):
         """Starts the beat task scheduler."""
-        self.app = app = app_or_default(app or self.app)
+        self.app = app = app or self.app
         self.loglevel = self._getopt('log_level', loglevel)
         self.logfile = self._getopt('log_file', logfile)
         self.schedule = self._getopt('schedule_filename', schedule)

+ 6 - 8
celery/beat.py

@@ -25,7 +25,6 @@ from . import __version__
 from . import platforms
 from . import signals
 from . import current_app
-from .app import app_or_default
 from .five import items, reraise, values
 from .schedules import maybe_schedule, crontab
 from .utils.imports import instantiate
@@ -135,7 +134,6 @@ class Scheduler(object):
     :keyword max_interval: see :attr:`max_interval`.
 
     """
-
     Entry = ScheduleEntry
 
     #: The schedule dict/shelve.
@@ -151,9 +149,9 @@ class Scheduler(object):
 
     logger = logger  # compat
 
-    def __init__(self, schedule=None, max_interval=None,
-                 app=None, Publisher=None, lazy=False, **kwargs):
-        app = self.app = app_or_default(app)
+    def __init__(self, app, schedule=None, max_interval=None,
+                 Publisher=None, lazy=False, **kwargs):
+        self.app = app
         self.data = maybe_promise({} if schedule is None else schedule)
         self.max_interval = (max_interval
                              or app.conf.CELERYBEAT_MAX_LOOP_INTERVAL
@@ -398,9 +396,9 @@ class PersistentScheduler(Scheduler):
 class Service(object):
     scheduler_cls = PersistentScheduler
 
-    def __init__(self, max_interval=None, schedule_filename=None,
-                 scheduler_cls=None, app=None):
-        app = self.app = app_or_default(app)
+    def __init__(self, app, max_interval=None, schedule_filename=None,
+                 scheduler_cls=None):
+        self.app = app
         self.max_interval = (max_interval
                              or app.conf.CELERYBEAT_MAX_LOOP_INTERVAL)
         self.scheduler_cls = scheduler_cls or self.scheduler_cls

+ 1 - 2
celery/bin/amqp.py

@@ -18,7 +18,6 @@ from itertools import count
 
 from amqp import Message
 
-from celery.app import app_or_default
 from celery.utils.functional import padlist
 
 from celery.bin.base import Command
@@ -328,7 +327,7 @@ class AMQPAdmin(object):
     Shell = AMQShell
 
     def __init__(self, *args, **kwargs):
-        self.app = app_or_default(kwargs.get('app'))
+        self.app = kwargs['app']
         self.out = kwargs.setdefault('out', sys.stderr)
         self.silent = kwargs.get('silent')
         self.args = args

+ 3 - 3
celery/events/cursesmon.py

@@ -56,8 +56,8 @@ class CursesMonitor(object):  # pragma: no cover
     greet = 'celery events {0}'.format(VERSION_BANNER)
     info_str = 'Info: '
 
-    def __init__(self, state, keymap=None, app=None):
-        self.app = app_or_default(app)
+    def __init__(self, state, app, keymap=None):
+        self.app = app
         self.keymap = keymap or self.keymap
         self.state = state
         default_keymap = {'J': self.move_selection_down,
@@ -521,7 +521,7 @@ def capture_events(app, state, display):  # pragma: no cover
 def evtop(app=None):  # pragma: no cover
     app = app_or_default(app)
     state = app.events.State()
-    display = CursesMonitor(state, app=app)
+    display = CursesMonitor(state, app)
     display.init_screen()
     refresher = DisplayThread(display)
     refresher.start()

+ 2 - 3
celery/loaders/base.py

@@ -64,9 +64,8 @@ class BaseLoader(object):
 
     _conf = None
 
-    def __init__(self, app=None, **kwargs):
-        from celery.app import app_or_default
-        self.app = app_or_default(app)
+    def __init__(self, app, **kwargs):
+        self.app = app
         self.task_modules = set()
 
     def now(self, utc=True):

+ 1 - 1
celery/schedules.py

@@ -529,7 +529,7 @@ class crontab(schedule):
                     other.day_of_week == self.day_of_week and
                     other.hour == self.hour and
                     other.minute == self.minute)
-        return other is self
+        return NotImplemented
 
     def __ne__(self, other):
         return not self.__eq__(other)

+ 1 - 1
celery/tests/app/test_app.py

@@ -513,7 +513,7 @@ class test_App(Case):
             def mail_admins(*args, **kwargs):
                 return args, kwargs
 
-        self.app.loader = Loader()
+        self.app.loader = Loader(app=self.app)
         self.app.conf.ADMINS = None
         self.assertFalse(self.app.mail_admins('Subject', 'Body'))
         self.app.conf.ADMINS = [('George Costanza', 'george@vandelay.com')]

+ 5 - 3
celery/tests/app/test_beat.py

@@ -379,7 +379,9 @@ class test_PersistentScheduler(AppCase):
         s._store.clear.assert_called_with()
 
     def test_get_schedule(self):
-        s = create_persistent_scheduler()[0](schedule_filename='schedule')
+        s = create_persistent_scheduler()[0](
+            schedule_filename='schedule', app=self.app,
+        )
         s._store = {'entries': {}}
         s.schedule = {'foo': 'bar'}
         self.assertDictEqual(s.schedule, {'foo': 'bar'})
@@ -455,7 +457,7 @@ class test_EmbeddedService(AppCase):
 
         from billiard.process import Process
 
-        s = beat.EmbeddedService()
+        s = beat.EmbeddedService(app=self.app)
         self.assertIsInstance(s, Process)
         self.assertIsInstance(s.service, beat.Service)
         s.service = MockService()
@@ -475,7 +477,7 @@ class test_EmbeddedService(AppCase):
         self.assertTrue(s._popen.terminated)
 
     def test_start_stop_threaded(self):
-        s = beat.EmbeddedService(thread=True)
+        s = beat.EmbeddedService(thread=True, app=self.app)
         from threading import Thread
         self.assertIsInstance(s, Thread)
         self.assertIsInstance(s.service, beat.Service)

+ 10 - 9
celery/tests/app/test_loaders.py

@@ -70,7 +70,8 @@ class test_LoaderBase(AppCase):
 
     def test_read_configuration_no_env(self):
         self.assertDictEqual(
-            base.BaseLoader().read_configuration('FOO_X_S_WE_WQ_Q_WE'),
+            base.BaseLoader(app=self.app).read_configuration(
+                'FOO_X_S_WE_WQ_Q_WE'),
             {},
         )
 
@@ -146,7 +147,7 @@ class test_LoaderBase(AppCase):
 
     def test_mail_attribute(self):
         from celery.utils import mail
-        loader = base.BaseLoader()
+        loader = base.BaseLoader(app=self.app)
         self.assertIs(loader.mail, mail)
 
     def test_cmdline_config_ValueError(self):
@@ -154,12 +155,12 @@ class test_LoaderBase(AppCase):
             self.loader.cmdline_config_parser(['broker.port=foobar'])
 
 
-class test_DefaultLoader(Case):
+class test_DefaultLoader(AppCase):
 
     @patch('celery.loaders.base.find_module')
     def test_read_configuration_not_a_package(self, find_module):
         find_module.side_effect = NotAPackage()
-        l = default.Loader()
+        l = default.Loader(app=self.app)
         with self.assertRaises(NotAPackage):
             l.read_configuration()
 
@@ -169,7 +170,7 @@ class test_DefaultLoader(Case):
         os.environ['CELERY_CONFIG_MODULE'] = 'celeryconfig.py'
         try:
             find_module.side_effect = NotAPackage()
-            l = default.Loader()
+            l = default.Loader(app=self.app)
             with self.assertRaises(NotAPackage):
                 l.read_configuration()
         finally:
@@ -179,7 +180,7 @@ class test_DefaultLoader(Case):
     def test_read_configuration_importerror(self, find_module):
         default.C_WNOCONF = True
         find_module.side_effect = ImportError()
-        l = default.Loader()
+        l = default.Loader(app=self.app)
         with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
             l.read_configuration()
         default.C_WNOCONF = False
@@ -198,7 +199,7 @@ class test_DefaultLoader(Case):
         prevconfig = sys.modules.get(configname)
         sys.modules[configname] = celeryconfig
         try:
-            l = default.Loader()
+            l = default.Loader(app=self.app)
             settings = l.read_configuration()
             self.assertTupleEqual(settings.CELERY_IMPORTS, ('os', 'sys'))
             settings = l.read_configuration()
@@ -209,7 +210,7 @@ class test_DefaultLoader(Case):
                 sys.modules[configname] = prevconfig
 
     def test_import_from_cwd(self):
-        l = default.Loader()
+        l = default.Loader(app=self.app)
         old_path = list(sys.path)
         try:
             sys.path.remove(os.getcwd())
@@ -234,7 +235,7 @@ class test_DefaultLoader(Case):
                 raise ImportError(name)
 
         with warnings.catch_warnings(record=True):
-            l = _Loader()
+            l = _Loader(app=self.app)
             self.assertFalse(l.configured)
             context_executed[0] = True
         self.assertTrue(context_executed[0])

+ 8 - 10
celery/tests/backends/test_amqp.py

@@ -10,9 +10,7 @@ from pickle import dumps, loads
 from billiard.einfo import ExceptionInfo
 from mock import patch
 
-from celery import current_app
 from celery import states
-from celery.app import app_or_default
 from celery.backends.amqp import AMQPBackend
 from celery.exceptions import TimeoutError
 from celery.five import Empty, Queue, range
@@ -31,7 +29,7 @@ class test_AMQPBackend(AppCase):
 
     def create_backend(self, **opts):
         opts = dict(dict(serializer='pickle', persistent=False), **opts)
-        return AMQPBackend(**opts)
+        return AMQPBackend(self.app, **opts)
 
     def test_mark_as_done(self):
         tb1 = self.create_backend()
@@ -107,7 +105,7 @@ class test_AMQPBackend(AppCase):
             iterations[0] += 1
             raise KeyError('foo')
 
-        backend = AMQPBackend()
+        backend = AMQPBackend(self.app)
         from celery.app.amqp import TaskProducer
         prod, TaskProducer.publish = TaskProducer.publish, publish
         try:
@@ -172,7 +170,7 @@ class test_AMQPBackend(AppCase):
         class MockBackend(AMQPBackend):
             Queue = MockBinding
 
-        backend = MockBackend()
+        backend = MockBackend(self.app)
         backend._republish = Mock()
 
         yield results, backend, Message
@@ -251,7 +249,7 @@ class test_AMQPBackend(AppCase):
                 pass
 
         b = self.create_backend()
-        with current_app.pool.acquire_channel(block=False) as (_, channel):
+        with self.app.pool.acquire_channel(block=False) as (_, channel):
             binding = b._create_binding(uuid())
             consumer = b.Consumer(channel, binding, no_ack=True)
             with self.assertRaises(socket.timeout):
@@ -296,14 +294,14 @@ class test_AMQPBackend(AppCase):
             def Consumer(*args, **kwargs):
                 raise KeyError('foo')
 
-        b = Backend()
+        b = Backend(self.app)
         with self.assertRaises(KeyError):
             next(b.get_many(['id1']))
 
     def test_get_many_raises_inner_block(self):
         with patch('kombu.connection.Connection.drain_events') as drain:
             drain.side_effect = KeyError('foo')
-            b = AMQPBackend()
+            b = AMQPBackend(self.app)
             with self.assertRaises(KeyError):
                 next(b.get_many(['id1']))
 
@@ -314,13 +312,13 @@ class test_AMQPBackend(AppCase):
                 drain.side_effect = ValueError()
                 raise KeyError('foo')
             drain.side_effect = se
-            b = AMQPBackend()
+            b = AMQPBackend(self.app)
             with self.assertRaises(ValueError):
                 next(b.consume('id1'))
 
     def test_no_expires(self):
         b = self.create_backend(expires=None)
-        app = app_or_default()
+        app = self.app
         prev = app.conf.CELERY_TASK_RESULT_EXPIRES
         app.conf.CELERY_TASK_RESULT_EXPIRES = None
         try:

+ 7 - 6
celery/tests/backends/test_backends.py

@@ -2,21 +2,22 @@ from __future__ import absolute_import
 
 from mock import patch
 
-from celery import current_app
 from celery import backends
 from celery.backends.amqp import AMQPBackend
 from celery.backends.cache import CacheBackend
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class test_backends(Case):
+class test_backends(AppCase):
 
     def test_get_backend_aliases(self):
         expects = [('amqp', AMQPBackend),
                    ('cache', CacheBackend)]
         for expect_name, expect_cls in expects:
-            self.assertIsInstance(backends.get_backend_cls(expect_name)(),
-                                  expect_cls)
+            self.assertIsInstance(
+                backends.get_backend_cls(expect_name)(app=self.app),
+                expect_cls,
+            )
 
     def test_get_backend_cache(self):
         backends.get_backend_cls.clear()
@@ -32,7 +33,7 @@ class test_backends(Case):
             backends.get_backend_cls('fasodaopjeqijwqe')
 
     def test_default_backend(self):
-        self.assertEqual(backends.default_backend, current_app.backend)
+        self.assertEqual(backends.default_backend, self.app.backend)
 
     def test_backend_by_url(self, url='redis://localhost/1'):
         from celery.backends.redis import RedisBackend

+ 61 - 45
celery/tests/backends/test_base.py

@@ -7,7 +7,6 @@ from contextlib import contextmanager
 from mock import Mock, patch
 from nose import SkipTest
 
-from celery import current_app
 from celery.exceptions import ChordError
 from celery.five import items, range
 from celery.result import AsyncResult, GroupResult
@@ -40,10 +39,9 @@ else:
 Unpickleable = subclass_exception('Unpickleable', KeyError, 'foo.module')
 Impossible = subclass_exception('Impossible', object, 'foo.module')
 Lookalike = subclass_exception('Lookalike', wrapobject, 'foo.module')
-b = BaseBackend()
 
 
-class test_serialization(Case):
+class test_serialization(AppCase):
 
     def test_create_exception_cls(self):
         self.assertTrue(serialization.create_exception_cls('FooError', 'm'))
@@ -51,27 +49,32 @@ class test_serialization(Case):
                                                            KeyError))
 
 
-class test_BaseBackend_interface(Case):
+class test_BaseBackend_interface(AppCase):
+
+    def setup(self):
+        self.b = BaseBackend(self.app)
 
     def test__forget(self):
         with self.assertRaises(NotImplementedError):
-            b._forget('SOMExx-N0Nex1stant-IDxx-')
+            self.b._forget('SOMExx-N0Nex1stant-IDxx-')
 
     def test_forget(self):
         with self.assertRaises(NotImplementedError):
-            b.forget('SOMExx-N0nex1stant-IDxx-')
+            self.b.forget('SOMExx-N0nex1stant-IDxx-')
 
     def test_on_chord_part_return(self):
-        b.on_chord_part_return(None)
+        self.b.on_chord_part_return(None)
 
     def test_on_chord_apply(self, unlock='celery.chord_unlock'):
-        p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
+        p, self.app.tasks[unlock] = self.app.tasks.get(unlock), Mock()
         try:
-            b.on_chord_apply('dakj221', 'sdokqweok',
-                             result=[AsyncResult(x) for x in [1, 2, 3]])
-            self.assertTrue(current_app.tasks[unlock].apply_async.call_count)
+            self.b.on_chord_apply(
+                'dakj221', 'sdokqweok',
+                result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
+            )
+            self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
         finally:
-            current_app.tasks[unlock] = p
+            self.app.tasks[unlock] = p
 
 
 class test_exception_pickle(Case):
@@ -93,19 +96,22 @@ class test_exception_pickle(Case):
         self.assertIsNone(fnpe(Impossible()))
 
 
-class test_prepare_exception(Case):
+class test_prepare_exception(AppCase):
+
+    def setup(self):
+        self.b = BaseBackend(self.app)
 
     def test_unpickleable(self):
-        x = b.prepare_exception(Unpickleable(1, 2, 'foo'))
+        x = self.b.prepare_exception(Unpickleable(1, 2, 'foo'))
         self.assertIsInstance(x, KeyError)
-        y = b.exception_to_python(x)
+        y = self.b.exception_to_python(x)
         self.assertIsInstance(y, KeyError)
 
     def test_impossible(self):
-        x = b.prepare_exception(Impossible())
+        x = self.b.prepare_exception(Impossible())
         self.assertIsInstance(x, UnpickleableExceptionWrapper)
         self.assertTrue(str(x))
-        y = b.exception_to_python(x)
+        y = self.b.exception_to_python(x)
         self.assertEqual(y.__class__.__name__, 'Impossible')
         if sys.version_info < (2, 5):
             self.assertTrue(y.__class__.__module__)
@@ -113,18 +119,18 @@ class test_prepare_exception(Case):
             self.assertEqual(y.__class__.__module__, 'foo.module')
 
     def test_regular(self):
-        x = b.prepare_exception(KeyError('baz'))
+        x = self.b.prepare_exception(KeyError('baz'))
         self.assertIsInstance(x, KeyError)
-        y = b.exception_to_python(x)
+        y = self.b.exception_to_python(x)
         self.assertIsInstance(y, KeyError)
 
 
 class KVBackend(KeyValueStoreBackend):
     mget_returns_dict = False
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, app, *args, **kwargs):
         self.db = {}
-        super(KVBackend, self).__init__()
+        super(KVBackend, self).__init__(app)
 
     def get(self, key):
         return self.db.get(key)
@@ -160,17 +166,17 @@ class DictBackend(BaseBackend):
         self._data.pop(group_id, None)
 
 
-class test_BaseBackend_dict(Case):
+class test_BaseBackend_dict(AppCase):
 
-    def setUp(self):
-        self.b = DictBackend()
+    def setup(self):
+        self.b = DictBackend(app=self.app)
 
     def test_delete_group(self):
         self.b.delete_group('can-delete')
         self.assertNotIn('can-delete', self.b._data)
 
     def test_prepare_exception_json(self):
-        x = DictBackend(serializer='json')
+        x = DictBackend(self.app, serializer='json')
         e = x.prepare_exception(KeyError('foo'))
         self.assertIn('exc_type', e)
         e = x.exception_to_python(e)
@@ -178,13 +184,13 @@ class test_BaseBackend_dict(Case):
         self.assertEqual(str(e), "'foo'")
 
     def test_save_group(self):
-        b = BaseBackend()
+        b = BaseBackend(self.app)
         b._save_group = Mock()
         b.save_group('foofoo', 'xxx')
         b._save_group.assert_called_with('foofoo', 'xxx')
 
     def test_forget_interface(self):
-        b = BaseBackend()
+        b = BaseBackend(self.app)
         with self.assertRaises(NotImplementedError):
             b.forget('foo')
 
@@ -230,7 +236,7 @@ class test_BaseBackend_dict(Case):
 class test_KeyValueStoreBackend(AppCase):
 
     def setup(self):
-        self.b = KVBackend()
+        self.b = KVBackend(app=self.app)
 
     def test_on_chord_part_return(self):
         assert not self.b.implements_incr
@@ -330,7 +336,9 @@ class test_KeyValueStoreBackend(AppCase):
 
     def test_chord_part_return_join_raises_task(self):
         with self._chord_part_context(self.b) as (task, deps, callback):
-            deps._failed_join_report = lambda: iter([AsyncResult('culprit')])
+            deps._failed_join_report = lambda: iter([
+                self.app.AsyncResult('culprit'),
+            ])
             deps.join_native.side_effect = KeyError('foo')
             self.b.on_chord_part_return(task)
             self.assertTrue(self.b.fail_from_current_stack.called)
@@ -340,15 +348,21 @@ class test_KeyValueStoreBackend(AppCase):
             self.assertIn('Dependency culprit raised', str(exc))
 
     def test_restore_group_from_json(self):
-        b = KVBackend(serializer='json')
-        g = GroupResult('group_id', [AsyncResult('a'), AsyncResult('b')])
+        b = KVBackend(serializer='json', app=self.app)
+        g = self.app.GroupResult(
+            'group_id',
+            [self.app.AsyncResult('a'), self.app.AsyncResult('b')],
+        )
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
         self.assertEqual(g2, g)
 
     def test_restore_group_from_pickle(self):
-        b = KVBackend(serializer='pickle')
-        g = GroupResult('group_id', [AsyncResult('a'), AsyncResult('b')])
+        b = KVBackend(serializer='pickle', app=self.app)
+        g = self.app.GroupResult(
+            'group_id',
+            [self.app.AsyncResult('a'), self.app.AsyncResult('b')],
+        )
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
         self.assertEqual(g2, g)
@@ -367,7 +381,9 @@ class test_KeyValueStoreBackend(AppCase):
 
     def test_save_restore_delete_group(self):
         tid = uuid()
-        tsr = GroupResult(tid, [AsyncResult(uuid()) for _ in range(10)])
+        tsr = self.app.GroupResult(
+            tid, [self.app.AsyncResult(uuid()) for _ in range(10)],
+        )
         self.b.save_group(tid, tsr)
         self.b.restore_group(tid)
         self.assertEqual(self.b.restore_group(tid), tsr)
@@ -378,41 +394,41 @@ class test_KeyValueStoreBackend(AppCase):
         self.assertIsNone(self.b.restore_group('xxx-nonexistant'))
 
 
-class test_KeyValueStoreBackend_interface(Case):
+class test_KeyValueStoreBackend_interface(AppCase):
 
     def test_get(self):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().get('a')
+            KeyValueStoreBackend(self.app).get('a')
 
     def test_set(self):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().set('a', 1)
+            KeyValueStoreBackend(self.app).set('a', 1)
 
     def test_incr(self):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().incr('a')
+            KeyValueStoreBackend(self.app).incr('a')
 
     def test_cleanup(self):
-        self.assertFalse(KeyValueStoreBackend().cleanup())
+        self.assertFalse(KeyValueStoreBackend(self.app).cleanup())
 
     def test_delete(self):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().delete('a')
+            KeyValueStoreBackend(self.app).delete('a')
 
     def test_mget(self):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().mget(['a'])
+            KeyValueStoreBackend(self.app).mget(['a'])
 
     def test_forget(self):
         with self.assertRaises(NotImplementedError):
-            KeyValueStoreBackend().forget('a')
+            KeyValueStoreBackend(self.app).forget('a')
 
 
-class test_DisabledBackend(Case):
+class test_DisabledBackend(AppCase):
 
     def test_store_result(self):
-        DisabledBackend().store_result()
+        DisabledBackend(self.app).store_result()
 
     def test_is_disabled(self):
         with self.assertRaises(NotImplementedError):
-            DisabledBackend().get_status('foo')
+            DisabledBackend(self.app).get_status('foo')

+ 10 - 11
celery/tests/backends/test_cache.py

@@ -8,7 +8,6 @@ from contextlib import contextmanager
 from kombu.utils.encoding import str_to_bytes
 from mock import Mock, patch
 
-from celery import current_app
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
@@ -67,13 +66,13 @@ class test_CacheBackend(AppCase):
             self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
 
     def test_on_chord_apply(self):
-        tb = CacheBackend(backend='memory://')
+        tb = CacheBackend(backend='memory://', app=self.app)
         gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
         tb.on_chord_apply(gid, {}, result=res)
 
     @patch('celery.result.GroupResult')
     def test_on_chord_part_return(self, setresult):
-        tb = CacheBackend(backend='memory://')
+        tb = CacheBackend(backend='memory://', app=self.app)
 
         deps = Mock()
         deps.__len__ = Mock()
@@ -82,7 +81,7 @@ class test_CacheBackend(AppCase):
         task = Mock()
         task.name = 'foobarbaz'
         try:
-            current_app.tasks['foobarbaz'] = task
+            self.app.tasks['foobarbaz'] = task
             task.request.chord = subtask(task)
 
             gid, res = uuid(), [AsyncResult(uuid()) for _ in range(3)]
@@ -98,7 +97,7 @@ class test_CacheBackend(AppCase):
             deps.delete.assert_called_with()
 
         finally:
-            current_app.tasks.pop('foobarbaz')
+            self.app.tasks.pop('foobarbaz')
 
     def test_mget(self):
         self.tb.set('foo', 1)
@@ -117,12 +116,12 @@ class test_CacheBackend(AppCase):
         self.tb.process_cleanup()
 
     def test_expires_as_int(self):
-        tb = CacheBackend(backend='memory://', expires=10)
+        tb = CacheBackend(backend='memory://', expires=10, app=self.app)
         self.assertEqual(tb.expires, 10)
 
     def test_unknown_backend_raises_ImproperlyConfigured(self):
         with self.assertRaises(ImproperlyConfigured):
-            CacheBackend(backend='unknown://')
+            CacheBackend(backend='unknown://', app=self.app)
 
 
 class MyMemcachedStringEncodingError(Exception):
@@ -218,7 +217,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     from celery.backends import cache
                     cache._imp = [None]
                     task_id, result = string(uuid()), 42
-                    b = cache.CacheBackend(backend='memcache')
+                    b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, status=states.SUCCESS)
                     self.assertEqual(b.get_result(task_id), result)
 
@@ -229,7 +228,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     from celery.backends import cache
                     cache._imp = [None]
                     task_id, result = str_to_bytes(uuid()), 42
-                    b = cache.CacheBackend(backend='memcache')
+                    b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, status=states.SUCCESS)
                     self.assertEqual(b.get_result(task_id), result)
 
@@ -239,7 +238,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 from celery.backends import cache
                 cache._imp = [None]
                 task_id, result = string(uuid()), 42
-                b = cache.CacheBackend(backend='memcache')
+                b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, status=states.SUCCESS)
                 self.assertEqual(b.get_result(task_id), result)
 
@@ -249,6 +248,6 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 from celery.backends import cache
                 cache._imp = [None]
                 task_id, result = str_to_bytes(uuid()), 42
-                b = cache.CacheBackend(backend='memcache')
+                b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, status=states.SUCCESS)
                 self.assertEqual(b.get_result(task_id), result)

+ 18 - 19
celery/tests/backends/test_database.py

@@ -6,13 +6,12 @@ from nose import SkipTest
 from pickle import loads, dumps
 
 from celery import states
-from celery.app import app_or_default
 from celery.exceptions import ImproperlyConfigured
 from celery.result import AsyncResult
 from celery.utils import uuid
 
 from celery.tests.case import (
-    Case,
+    AppCase,
     mask_modules,
     skip_if_pypy,
     skip_if_jython,
@@ -33,11 +32,11 @@ class SomeClass(object):
         self.data = data
 
 
-class test_DatabaseBackend(Case):
+class test_DatabaseBackend(AppCase):
 
     @skip_if_pypy
     @skip_if_jython
-    def setUp(self):
+    def setup(self):
         if DatabaseBackend is None:
             raise SkipTest('sqlalchemy not installed')
 
@@ -62,20 +61,20 @@ class test_DatabaseBackend(Case):
                 _sqlalchemy_installed()
 
     def test_missing_dburi_raises_ImproperlyConfigured(self):
-        conf = app_or_default().conf
+        conf = self.app.conf
         prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
         try:
             with self.assertRaises(ImproperlyConfigured):
-                DatabaseBackend()
+                DatabaseBackend(app=self.app)
         finally:
             conf.CELERY_RESULT_DBURI = prev
 
     def test_missing_task_id_is_PENDING(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         self.assertEqual(tb.get_status('xxx-does-not-exist'), states.PENDING)
 
     def test_missing_task_meta_is_dict_with_pending(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         self.assertDictContainsSubset({
             'status': states.PENDING,
             'task_id': 'xxx-does-not-exist-at-all',
@@ -84,7 +83,7 @@ class test_DatabaseBackend(Case):
         }, tb.get_task_meta('xxx-does-not-exist-at-all'))
 
     def test_mark_as_done(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
         tid = uuid()
 
@@ -96,7 +95,7 @@ class test_DatabaseBackend(Case):
         self.assertEqual(tb.get_result(tid), 42)
 
     def test_is_pickled(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
         tid2 = uuid()
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
@@ -107,19 +106,19 @@ class test_DatabaseBackend(Case):
         self.assertEqual(rindb.get('bar').data, 12345)
 
     def test_mark_as_started(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tid = uuid()
         tb.mark_as_started(tid)
         self.assertEqual(tb.get_status(tid), states.STARTED)
 
     def test_mark_as_revoked(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tid = uuid()
         tb.mark_as_revoked(tid)
         self.assertEqual(tb.get_status(tid), states.REVOKED)
 
     def test_mark_as_retry(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tid = uuid()
         try:
             raise KeyError('foo')
@@ -132,7 +131,7 @@ class test_DatabaseBackend(Case):
             self.assertEqual(tb.get_traceback(tid), trace)
 
     def test_mark_as_failure(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
         tid3 = uuid()
         try:
@@ -146,7 +145,7 @@ class test_DatabaseBackend(Case):
             self.assertEqual(tb.get_traceback(tid3), trace)
 
     def test_forget(self):
-        tb = DatabaseBackend(backend='memory://')
+        tb = DatabaseBackend(backend='memory://', app=self.app)
         tid = uuid()
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
@@ -155,15 +154,15 @@ class test_DatabaseBackend(Case):
         self.assertIsNone(x.result)
 
     def test_process_cleanup(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         tb.process_cleanup()
 
     def test_reduce(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         self.assertTrue(loads(dumps(tb)))
 
     def test_save__restore__delete_group(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
 
         tid = uuid()
         res = {'something': 'special'}
@@ -178,7 +177,7 @@ class test_DatabaseBackend(Case):
         self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
 
     def test_cleanup(self):
-        tb = DatabaseBackend()
+        tb = DatabaseBackend(app=self.app)
         for i in range(10):
             tb.mark_as_done(uuid(), 42)
             tb.save_group(uuid(), {'foo': 'bar'})

+ 7 - 7
celery/tests/backends/test_mongodb.py

@@ -26,7 +26,7 @@ MONGODB_COLLECTION = 'collection1'
 
 class test_MongoBackend(AppCase):
 
-    def setUp(self):
+    def setup(self):
         if pymongo is None:
             raise SkipTest('pymongo is not installed.')
 
@@ -36,9 +36,9 @@ class test_MongoBackend(AppCase):
         R['Binary'], module.Binary = module.Binary, Mock()
         R['datetime'], datetime.datetime = datetime.datetime, Mock()
 
-        self.backend = MongoBackend()
+        self.backend = MongoBackend(app=self.app)
 
-    def tearDown(self):
+    def teardown(self):
         MongoBackend.encode = self._reset['encode']
         MongoBackend.decode = self._reset['decode']
         module.Binary = self._reset['Binary']
@@ -53,7 +53,7 @@ class test_MongoBackend(AppCase):
         prev, module.pymongo = module.pymongo, None
         try:
             with self.assertRaises(ImproperlyConfigured):
-                MongoBackend()
+                MongoBackend(app=self.app)
         finally:
             module.pymongo = prev
 
@@ -69,14 +69,14 @@ class test_MongoBackend(AppCase):
         MongoBackend(app=celery)
 
     def test_restore_group_no_entry(self):
-        x = MongoBackend()
+        x = MongoBackend(app=self.app)
         x.collection = Mock()
         fo = x.collection.find_one = Mock()
         fo.return_value = None
         self.assertIsNone(x._restore_group('1f3fab'))
 
     def test_reduce(self):
-        x = MongoBackend()
+        x = MongoBackend(app=self.app)
         self.assertTrue(loads(dumps(x)))
 
     def test_get_connection_connection_exists(self):
@@ -311,7 +311,7 @@ class test_MongoBackend(AppCase):
         mock_collection.assert_called_once()
 
     def test_get_database_authfailure(self):
-        x = MongoBackend()
+        x = MongoBackend(app=self.app)
         x._get_connection = Mock()
         conn = x._get_connection.return_value = {}
         db = conn[x.mongodb_database] = Mock()

+ 24 - 25
celery/tests/backends/test_redis.py

@@ -8,7 +8,6 @@ from pickle import loads, dumps
 
 from kombu.utils import cached_property, uuid
 
-from celery import current_app
 from celery import states
 from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
@@ -16,7 +15,7 @@ from celery.result import AsyncResult
 from celery.task import subtask
 from celery.utils.timeutils import timedelta_seconds
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 class Redis(object):
@@ -65,7 +64,7 @@ class redis(object):
             pass
 
 
-class test_RedisBackend(Case):
+class test_RedisBackend(AppCase):
 
     def get_backend(self):
         from celery.backends import redis
@@ -75,7 +74,7 @@ class test_RedisBackend(Case):
 
         return RedisBackend
 
-    def setUp(self):
+    def setup(self):
         self.Backend = self.get_backend()
 
         class MockBackend(self.Backend):
@@ -89,7 +88,7 @@ class test_RedisBackend(Case):
     def test_reduce(self):
         try:
             from celery.backends.redis import RedisBackend
-            x = RedisBackend()
+            x = RedisBackend(app=self.app)
             self.assertTrue(loads(dumps(x)))
         except ImportError:
             raise SkipTest('redis not installed')
@@ -97,10 +96,10 @@ class test_RedisBackend(Case):
     def test_no_redis(self):
         self.MockBackend.redis = None
         with self.assertRaises(ImproperlyConfigured):
-            self.MockBackend()
+            self.MockBackend(app=self.app)
 
     def test_url(self):
-        x = self.MockBackend('redis://foobar//1')
+        x = self.MockBackend('redis://foobar//1', app=self.app)
         self.assertEqual(x.host, 'foobar')
         self.assertEqual(x.db, '1')
 
@@ -108,54 +107,54 @@ class test_RedisBackend(Case):
         conf = AttributeDict({'CELERY_RESULT_SERIALIZER': 'json',
                               'CELERY_MAX_CACHED_RESULTS': 1,
                               'CELERY_TASK_RESULT_EXPIRES': None})
-        prev, current_app.conf = current_app.conf, conf
+        prev, self.app.conf = self.app.conf, conf
         try:
-            self.MockBackend()
+            self.MockBackend(app=self.app)
         finally:
-            current_app.conf = prev
+            self.app.conf = prev
 
     def test_expires_defaults_to_config(self):
-        conf = current_app.conf
+        conf = self.app.conf
         prev = conf.CELERY_TASK_RESULT_EXPIRES
         conf.CELERY_TASK_RESULT_EXPIRES = 10
         try:
-            b = self.Backend(expires=None)
+            b = self.Backend(expires=None, app=self.app)
             self.assertEqual(b.expires, 10)
         finally:
             conf.CELERY_TASK_RESULT_EXPIRES = prev
 
     def test_expires_is_int(self):
-        b = self.Backend(expires=48)
+        b = self.Backend(expires=48, app=self.app)
         self.assertEqual(b.expires, 48)
 
     def test_expires_is_None(self):
-        b = self.Backend(expires=None)
+        b = self.Backend(expires=None, app=self.app)
         self.assertEqual(b.expires, timedelta_seconds(
-            current_app.conf.CELERY_TASK_RESULT_EXPIRES))
+            self.app.conf.CELERY_TASK_RESULT_EXPIRES))
 
     def test_expires_is_timedelta(self):
-        b = self.Backend(expires=timedelta(minutes=1))
+        b = self.Backend(expires=timedelta(minutes=1), app=self.app)
         self.assertEqual(b.expires, 60)
 
     def test_on_chord_apply(self):
-        self.Backend().on_chord_apply(
+        self.Backend(app=self.app).on_chord_apply(
             'group_id', {},
             result=[AsyncResult(x) for x in [1, 2, 3]],
         )
 
     def test_mget(self):
-        b = self.MockBackend()
+        b = self.MockBackend(app=self.app)
         self.assertTrue(b.mget(['a', 'b', 'c']))
         b.client.mget.assert_called_with(['a', 'b', 'c'])
 
     def test_set_no_expire(self):
-        b = self.MockBackend()
+        b = self.MockBackend(app=self.app)
         b.expires = None
         b.set('foo', 'bar')
 
     @patch('celery.result.GroupResult')
     def test_on_chord_part_return(self, setresult):
-        b = self.MockBackend()
+        b = self.MockBackend(app=self.app)
         deps = Mock()
         deps.__len__ = Mock()
         deps.__len__.return_value = 10
@@ -164,7 +163,7 @@ class test_RedisBackend(Case):
         task = Mock()
         task.name = 'foobarbaz'
         try:
-            current_app.tasks['foobarbaz'] = task
+            self.app.tasks['foobarbaz'] = task
             task.request.chord = subtask(task)
             task.request.group = 'group_id'
 
@@ -178,13 +177,13 @@ class test_RedisBackend(Case):
 
             self.assertTrue(b.client.expire.call_count)
         finally:
-            current_app.tasks.pop('foobarbaz')
+            self.app.tasks.pop('foobarbaz')
 
     def test_process_cleanup(self):
-        self.Backend().process_cleanup()
+        self.Backend(app=self.app).process_cleanup()
 
     def test_get_set_forget(self):
-        b = self.Backend()
+        b = self.Backend(app=self.app)
         tid = uuid()
         b.store_result(tid, 42, states.SUCCESS)
         self.assertEqual(b.get_status(tid), states.SUCCESS)
@@ -193,7 +192,7 @@ class test_RedisBackend(Case):
         self.assertEqual(b.get_status(tid), states.PENDING)
 
     def test_set_expires(self):
-        b = self.Backend(expires=512)
+        b = self.Backend(expires=512, app=self.app)
         tid = uuid()
         key = b.get_key_for_task(tid)
         b.store_result(tid, 42, states.SUCCESS)

+ 13 - 13
celery/tests/bin/test_beat.py

@@ -10,7 +10,6 @@ from mock import patch
 
 from celery import beat
 from celery import platforms
-from celery.app import app_or_default
 from celery.bin import beat as beat_bin
 from celery.apps import beat as beatapp
 
@@ -64,10 +63,10 @@ class MockBeat3(beatapp.Beat):
 class test_Beat(AppCase):
 
     def test_loglevel_string(self):
-        b = beatapp.Beat(loglevel='DEBUG')
+        b = beatapp.Beat(app=self.app, loglevel='DEBUG')
         self.assertEqual(b.loglevel, logging.DEBUG)
 
-        b2 = beatapp.Beat(loglevel=logging.DEBUG)
+        b2 = beatapp.Beat(app=self.app, loglevel=logging.DEBUG)
         self.assertEqual(b2.loglevel, logging.DEBUG)
 
     def test_colorize(self):
@@ -80,15 +79,15 @@ class test_Beat(AppCase):
         self.assertEqual(app.log.setup.call_args[1]['colorize'], False)
 
     def test_init_loader(self):
-        b = beatapp.Beat()
+        b = beatapp.Beat(app=self.app)
         b.init_loader()
 
     def test_process_title(self):
-        b = beatapp.Beat()
+        b = beatapp.Beat(app=self.app)
         b.set_process_title()
 
     def test_run(self):
-        b = MockBeat2()
+        b = MockBeat2(app=self.app)
         MockService.started = False
         b.run()
         self.assertTrue(MockService.started)
@@ -109,8 +108,8 @@ class test_Beat(AppCase):
             platforms.signals = p
 
     def test_install_sync_handler(self):
-        b = beatapp.Beat()
-        clock = MockService()
+        b = beatapp.Beat(app=self.app)
+        clock = MockService(app=self.app)
         MockService.in_sync = False
         handlers = self.psig(b.install_sync_handler, clock)
         with self.assertRaises(SystemExit):
@@ -124,7 +123,7 @@ class test_Beat(AppCase):
             delattr(sys.stdout, 'logger')
         except AttributeError:
             pass
-        b = beatapp.Beat()
+        b = beatapp.Beat(app=self.app)
         b.redirect_stdouts = False
         b.app.log.__class__._setup = False
         b.setup_logging()
@@ -134,14 +133,15 @@ class test_Beat(AppCase):
     @redirect_stdouts
     @patch('celery.apps.beat.logger')
     def test_logs_errors(self, logger, stdout, stderr):
-        b = MockBeat3(socket_timeout=None)
+        b = MockBeat3(app=self.app, socket_timeout=None)
         b.start_scheduler()
         self.assertTrue(logger.critical.called)
 
     @redirect_stdouts
     @patch('celery.platforms.create_pidlock')
     def test_use_pidfile(self, create_pidlock, stdout, stderr):
-        b = MockBeat2(pidfile='pidfilelockfilepid', socket_timeout=None)
+        b = MockBeat2(app=self.app, pidfile='pidfilelockfilepid',
+                      socket_timeout=None)
         b.start_scheduler()
         self.assertTrue(create_pidlock.called)
 
@@ -184,13 +184,13 @@ class test_div(AppCase):
 
     def test_detach(self):
         cmd = beat_bin.beat()
-        cmd.app = app_or_default()
+        cmd.app = self.app
         cmd.run(detach=True)
         self.assertTrue(MockDaemonContext.opened)
         self.assertTrue(MockDaemonContext.closed)
 
     def test_parse_options(self):
         cmd = beat_bin.beat()
-        cmd.app = app_or_default()
+        cmd.app = self.app
         options, args = cmd.parse_options('celery beat', ['-s', 'foo'])
         self.assertEqual(options.schedule, 'foo')

+ 2 - 2
celery/tests/bin/test_celeryd_detach.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 from mock import Mock, patch
 
-from celery import current_app
+from celery.platforms import IS_WINDOWS
 from celery.bin.celeryd_detach import (
     detach,
     detached_celeryd,
@@ -12,7 +12,7 @@ from celery.bin.celeryd_detach import (
 from celery.tests.case import Case, override_stdouts
 
 
-if not current_app.IS_WINDOWS:
+if not IS_WINDOWS:
     class test_detached(Case):
 
         @patch('celery.bin.celeryd_detach.detached')

+ 3 - 5
celery/tests/bin/test_events.py

@@ -3,10 +3,9 @@ from __future__ import absolute_import
 from nose import SkipTest
 from mock import patch as mpatch
 
-from celery.app import app_or_default
 from celery.bin import events
 
-from celery.tests.case import Case, _old_patch as patch
+from celery.tests.case import AppCase, _old_patch as patch
 
 
 class MockCommand(object):
@@ -21,10 +20,9 @@ def proctitle(prog, info=None):
 proctitle.last = ()
 
 
-class test_events(Case):
+class test_events(AppCase):
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.ev = events.events(app=self.app)
 
     @patch('celery.events.dumper', 'evdump', lambda **kw: 'me dumper, you?')

+ 34 - 32
celery/tests/bin/test_worker.py

@@ -15,7 +15,6 @@ from kombu import Exchange, Queue
 from celery import Celery
 from celery import platforms
 from celery import signals
-from celery import current_app
 from celery.app import trace
 from celery.apps import worker as cd
 from celery.bin.worker import worker, main as worker_main
@@ -152,7 +151,7 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_loglevel_string(self):
-        worker = self.Worker(loglevel='INFO')
+        worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
 
     def test_run_worker(self):
@@ -166,14 +165,14 @@ class test_Worker(WorkerAppCase):
         p = platforms.signals
         platforms.signals = Signals()
         try:
-            w = self.Worker()
+            w = self.Worker(app=self.app)
             w._isatty = False
             w.on_start()
             for sig in 'SIGINT', 'SIGHUP', 'SIGTERM':
                 self.assertIn(sig, handlers)
 
             handlers.clear()
-            w = self.Worker()
+            w = self.Worker(app=self.app)
             w._isatty = True
             w.on_start()
             for sig in 'SIGINT', 'SIGTERM':
@@ -184,7 +183,7 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_startup_info(self):
-        worker = self.Worker()
+        worker = self.Worker(app=self.app)
         worker.on_start()
         self.assertTrue(worker.startup_info())
         worker.loglevel = logging.DEBUG
@@ -211,7 +210,7 @@ class test_Worker(WorkerAppCase):
             app.loader = prev
 
         from celery.loaders.app import AppLoader
-        prev, app.loader = app.loader, AppLoader()
+        prev, app.loader = app.loader, AppLoader(app=self.app)
         try:
             self.assertTrue(worker.startup_info())
         finally:
@@ -227,18 +226,18 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_run(self):
-        self.Worker().on_start()
-        self.Worker(purge=True).on_start()
-        worker = self.Worker()
+        self.Worker(app=self.app).on_start()
+        self.Worker(app=self.app, purge=True).on_start()
+        worker = self.Worker(app=self.app)
         worker.on_start()
 
     @disable_stdouts
     def test_purge_messages(self):
-        self.Worker().purge_messages()
+        self.Worker(app=self.app).purge_messages()
 
     @disable_stdouts
     def test_init_queues(self):
-        app = current_app
+        app = self.app
         c = app.conf
         p, app.amqp.queues = app.amqp.queues, app.amqp.Queues({
             'celery': {'exchange': 'celery',
@@ -246,7 +245,7 @@ class test_Worker(WorkerAppCase):
             'video': {'exchange': 'video',
                       'routing_key': 'video'}})
         try:
-            worker = self.Worker()
+            worker = self.Worker(app=self.app)
             worker.setup_queues(['video'])
             self.assertIn('video', app.amqp.queues)
             self.assertIn('video', app.amqp.queues.consume_from)
@@ -256,10 +255,10 @@ class test_Worker(WorkerAppCase):
             c.CELERY_CREATE_MISSING_QUEUES = False
             del(app.amqp.queues)
             with self.assertRaises(ImproperlyConfigured):
-                self.Worker().setup_queues(['image'])
+                self.Worker(app=self.app).setup_queues(['image'])
             del(app.amqp.queues)
             c.CELERY_CREATE_MISSING_QUEUES = True
-            worker = self.Worker()
+            worker = self.Worker(app=self.app)
             worker.setup_queues(queues=['image'])
             self.assertIn('image', app.amqp.queues.consume_from)
             self.assertEqual(Queue('image', Exchange('image'),
@@ -269,31 +268,32 @@ class test_Worker(WorkerAppCase):
 
     @disable_stdouts
     def test_autoscale_argument(self):
-        worker1 = self.Worker(autoscale='10,3')
+        worker1 = self.Worker(app=self.app, autoscale='10,3')
         self.assertListEqual(worker1.autoscale, [10, 3])
-        worker2 = self.Worker(autoscale='10')
+        worker2 = self.Worker(app=self.app, autoscale='10')
         self.assertListEqual(worker2.autoscale, [10, 0])
 
     def test_include_argument(self):
-        worker1 = self.Worker(include='some.module')
+        worker1 = self.Worker(app=self.app, include='some.module')
         self.assertListEqual(worker1.include, ['some.module'])
-        worker2 = self.Worker(include='some.module,another.package')
+        worker2 = self.Worker(app=self.app,
+                              include='some.module,another.package')
         self.assertListEqual(
             worker2.include,
             ['some.module', 'another.package'],
         )
-        self.Worker(include=['os', 'sys'])
+        self.Worker(app=self.app, include=['os', 'sys'])
 
     @disable_stdouts
     def test_unknown_loglevel(self):
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(loglevel='ALIEN')
-        worker1 = self.Worker(loglevel=0xFFFF)
+        worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
     @disable_stdouts
     def test_warns_if_running_as_privileged_user(self):
-        app = current_app
+        app = self.app
         if app.IS_WINDOWS:
             raise SkipTest('Not applicable on Windows')
 
@@ -305,14 +305,14 @@ class test_Worker(WorkerAppCase):
             with self.assertWarnsRegex(
                     RuntimeWarning,
                     r'superuser privileges is discouraged'):
-                worker = self.Worker()
+                worker = self.Worker(app=self.app)
                 worker.on_start()
         finally:
             os.getuid = prev
 
     @disable_stdouts
     def test_redirect_stdouts(self):
-        self.Worker(redirect_stdouts=False)
+        self.Worker(app=self.app, redirect_stdouts=False)
         with self.assertRaises(AttributeError):
             sys.stdout.logger
 
@@ -322,7 +322,7 @@ class test_Worker(WorkerAppCase):
             self.app.log.redirect_stdouts, Mock(),
         )
         try:
-            worker = self.Worker(redirect_stoutds=True)
+            worker = self.Worker(app=self.app, redirect_stoutds=True)
             worker._custom_logging = True
             worker.on_start()
             self.assertFalse(self.app.log.redirect_stdouts.called)
@@ -330,14 +330,16 @@ class test_Worker(WorkerAppCase):
             self.app.log.redirect_stdouts = prev
 
     def test_setup_logging_no_color(self):
-        worker = self.Worker(redirect_stdouts=False, no_color=True)
+        worker = self.Worker(
+            app=self.app, redirect_stdouts=False, no_color=True,
+        )
         prev, self.app.log.setup = self.app.log.setup, Mock()
         worker.setup_logging()
         self.assertFalse(self.app.log.setup.call_args[1]['colorize'])
 
     @disable_stdouts
     def test_startup_info_pool_is_str(self):
-        worker = self.Worker(redirect_stdouts=False)
+        worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker.pool_cls = 'foo'
         worker.startup_info()
 
@@ -349,7 +351,7 @@ class test_Worker(WorkerAppCase):
             logging_setup[0] = True
 
         try:
-            worker = self.Worker(redirect_stdouts=False)
+            worker = self.Worker(app=self.app, redirect_stdouts=False)
             worker.app.log.__class__._setup = False
             worker.setup_logging()
             self.assertTrue(logging_setup[0])
@@ -367,7 +369,7 @@ class test_Worker(WorkerAppCase):
             def osx_proxy_detection_workaround(self):
                 self.proxy_workaround_installed = True
 
-        worker = OSXWorker(redirect_stdouts=False)
+        worker = OSXWorker(app=self.app, redirect_stdouts=False)
 
         def install_HUP_nosupport(controller):
             controller.hup_not_supported_installed = True
@@ -400,7 +402,7 @@ class test_Worker(WorkerAppCase):
         prev = cd.install_worker_restart_handler
         cd.install_worker_restart_handler = install_worker_restart_handler
         try:
-            worker = self.Worker()
+            worker = self.Worker(app=self.app)
             worker.app.IS_OSX = False
             worker.install_platform_tweaks(Controller())
             self.assertTrue(restart_worker_handler_installed[0])
@@ -415,7 +417,7 @@ class test_Worker(WorkerAppCase):
         def on_worker_ready(**kwargs):
             worker_ready_sent[0] = True
 
-        self.Worker().on_consumer_ready(object())
+        self.Worker(app=self.app).on_consumer_ready(object())
         self.assertTrue(worker_ready_sent[0])
 
 
@@ -430,7 +432,7 @@ class test_funs(WorkerAppCase):
             __import__('setproctitle')
         except ImportError:
             raise SkipTest('setproctitle not installed')
-        worker = Worker(hostname='xyzza')
+        worker = Worker(app=self.app, hostname='xyzza')
         prev1, sys.argv = sys.argv, ['Arg0']
         try:
             st = worker.set_process_status('Running')
@@ -452,7 +454,7 @@ class test_funs(WorkerAppCase):
     @disable_stdouts
     def test_parse_options(self):
         cmd = worker()
-        cmd.app = current_app
+        cmd.app = self.app
         opts, args = cmd.parse_options('worker', ['--concurrency=512'])
         self.assertEqual(opts.concurrency, 512)
 

+ 2 - 18
celery/tests/case.py

@@ -30,7 +30,6 @@ from nose import SkipTest
 from kombu.log import NullHandler
 from kombu.utils import nested
 
-from celery.app import app_or_default
 from celery.five import (
     WhateverIO, builtins, items, reraise,
     string_t, values, open_fqdn,
@@ -236,9 +235,7 @@ def wrap_logger(logger, loglevel=logging.ERROR):
 
 
 @contextmanager
-def eager_tasks():
-    app = app_or_default()
-
+def eager_tasks(app):
     prev = app.conf.CELERY_ALWAYS_EAGER
     app.conf.CELERY_ALWAYS_EAGER = True
     try:
@@ -247,19 +244,6 @@ def eager_tasks():
         app.conf.CELERY_ALWAYS_EAGER = prev
 
 
-def with_eager_tasks(fun):
-
-    @wraps(fun)
-    def _inner(*args, **kwargs):
-        app = app_or_default()
-        prev = app.conf.CELERY_ALWAYS_EAGER
-        app.conf.CELERY_ALWAYS_EAGER = True
-        try:
-            return fun(*args, **kwargs)
-        finally:
-            app.conf.CELERY_ALWAYS_EAGER = prev
-
-
 def with_environ(env_name, env_value):
 
     def _envpatched(fun):
@@ -547,7 +531,7 @@ def patch_many(*targets):
 
 
 @contextmanager
-def patch_settings(app=None, **config):
+def patch_settings(app, **config):
     if app is None:
         from celery import current_app
         app = current_app

+ 13 - 12
celery/tests/compat_modules/test_sets.py

@@ -4,12 +4,11 @@ import anyjson
 
 from mock import Mock, patch
 
-from celery import current_app
 from celery.task import Task
 from celery.task.sets import subtask, TaskSet
 from celery.canvas import Signature
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 class MockTask(Task):
@@ -27,7 +26,7 @@ class MockTask(Task):
         return (args, kwargs, options)
 
 
-class test_subtask(Case):
+class test_subtask(AppCase):
 
     def test_behaves_like_type(self):
         s = subtask('tasks.add', (2, 2), {'cache': True},
@@ -104,15 +103,15 @@ class test_subtask(Case):
         self.assertDictEqual(dict(cls(*args)), dict(s))
 
 
-class test_TaskSet(Case):
+class test_TaskSet(AppCase):
 
     def test_task_arg_can_be_iterable__compat(self):
         ts = TaskSet([MockTask.subtask((i, i))
-                      for i in (2, 4, 8)])
+                      for i in (2, 4, 8)], app=self.app)
         self.assertEqual(len(ts), 3)
 
     def test_respects_ALWAYS_EAGER(self):
-        app = current_app
+        app = self.app
 
         class MockTaskSet(TaskSet):
             applied = 0
@@ -122,6 +121,7 @@ class test_TaskSet(Case):
 
         ts = MockTaskSet(
             [MockTask.subtask((i, i)) for i in (2, 4, 8)],
+            app=self.app,
         )
         app.conf.CELERY_ALWAYS_EAGER = True
         try:
@@ -145,7 +145,7 @@ class test_TaskSet(Case):
                 applied[0] += 1
 
         ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)])
+                      for i in (2, 4, 8)], app=self.app)
         ts.apply_async()
         self.assertEqual(applied[0], 3)
 
@@ -158,9 +158,10 @@ class test_TaskSet(Case):
 
         # setting current_task
 
-        @current_app.task
+        @self.app.task
         def xyz():
             pass
+
         from celery._state import _task_stack
         xyz.push_request()
         _task_stack.push(xyz)
@@ -180,21 +181,21 @@ class test_TaskSet(Case):
                 applied[0] += 1
 
         ts = TaskSet([mocksubtask(MockTask, (i, i))
-                      for i in (2, 4, 8)])
+                      for i in (2, 4, 8)], app=self.app)
         ts.apply()
         self.assertEqual(applied[0], 3)
 
     def test_set_app(self):
-        ts = TaskSet([])
+        ts = TaskSet([], app=self.app)
         ts.app = 42
         self.assertEqual(ts.app, 42)
 
     def test_set_tasks(self):
-        ts = TaskSet([])
+        ts = TaskSet([], app=self.app)
         ts.tasks = [1, 2, 3]
         self.assertEqual(ts, [1, 2, 3])
 
     def test_set_Publisher(self):
-        ts = TaskSet([])
+        ts = TaskSet([], app=self.app)
         ts.Publisher = 42
         self.assertEqual(ts.Publisher, 42)

+ 4 - 4
celery/tests/events/test_cursesmon.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 from nose import SkipTest
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 class MockWindow(object):
@@ -11,16 +11,16 @@ class MockWindow(object):
         return self.y, self.x
 
 
-class test_CursesDisplay(Case):
+class test_CursesDisplay(AppCase):
 
-    def setUp(self):
+    def setup(self):
         try:
             import curses  # noqa
         except ImportError:
             raise SkipTest('curses monitor requires curses')
 
         from celery.events import cursesmon
-        self.monitor = cursesmon.CursesMonitor(object())
+        self.monitor = cursesmon.CursesMonitor(object(), app=self.app)
         self.win = MockWindow()
         self.monitor.win = self.win
 

+ 6 - 9
celery/tests/events/test_snapshot.py

@@ -2,10 +2,9 @@ from __future__ import absolute_import
 
 from mock import patch
 
-from celery.app import app_or_default
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
 class TRef(object):
@@ -28,10 +27,9 @@ class MockTimer(object):
 timer = MockTimer()
 
 
-class test_Polaroid(Case):
+class test_Polaroid(AppCase):
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.state = self.app.events.State()
 
     def test_constructor(self):
@@ -99,7 +97,7 @@ class test_Polaroid(Case):
         self.assertEqual(shutter_signal_sent[0], 1)
 
 
-class test_evcam(Case):
+class test_evcam(AppCase):
 
     class MockReceiver(object):
         raise_keyboard_interrupt = False
@@ -113,12 +111,11 @@ class test_evcam(Case):
         def Receiver(self, *args, **kwargs):
             return test_evcam.MockReceiver()
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.prev, self.app.events = self.app.events, self.MockEvents()
         self.app.events.app = self.app
 
-    def tearDown(self):
+    def teardown(self):
         self.app.events = self.prev
 
     def test_evcam(self):

+ 3 - 3
celery/tests/security/case.py

@@ -2,12 +2,12 @@ from __future__ import absolute_import
 
 from nose import SkipTest
 
-from celery.tests.case import Case
+from celery.tests.case import AppCase
 
 
-class SecurityCase(Case):
+class SecurityCase(AppCase):
 
-    def setUp(self):
+    def setup(self):
         try:
             from OpenSSL import crypto  # noqa
         except ImportError:

+ 40 - 28
celery/tests/security/test_security.py

@@ -18,10 +18,9 @@ from __future__ import absolute_import
 
 from mock import Mock, patch
 
-from celery import current_app
 from celery.exceptions import ImproperlyConfigured, SecurityError
 from celery.five import builtins
-from celery.security import setup_security, disable_untrusted_serializers
+from celery.security import disable_untrusted_serializers
 from celery.security.utils import reraise_errors
 from kombu.serialization import registry
 
@@ -55,11 +54,14 @@ class test_security(SecurityCase):
         disabled = registry._disabled_content_types
         self.assertEqual(0, len(disabled))
 
-        current_app.conf.CELERY_TASK_SERIALIZER = 'json'
-
-        setup_security()
-        self.assertIn('application/x-python-serialize', disabled)
-        disabled.clear()
+        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
+            self.app.conf.CELERY_TASK_SERIALIZER, 'json')
+        try:
+            self.app.setup_security()
+            self.assertIn('application/x-python-serialize', disabled)
+            disabled.clear()
+        finally:
+            self.app.conf.CELERY_TASK_SERIALIZER = prev
 
     @patch('celery.security.register_auth')
     @patch('celery.security.disable_untrusted_serializers')
@@ -74,29 +76,39 @@ class test_security(SecurityCase):
             finally:
                 calls[0] += 1
 
-        with mock_open(side_effect=effect):
-            with patch('celery.security.registry') as registry:
-                store = Mock()
-                setup_security(['json'], key, cert, store)
-                dis.assert_called_with(['json'])
-                reg.assert_called_with('A', 'B', store, 'sha1', 'json')
-                registry._set_default_serializer.assert_called_with('auth')
+        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
+                self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
+        try:
+            with mock_open(side_effect=effect):
+                with patch('celery.security.registry') as registry:
+                    store = Mock()
+                    self.app.setup_security(['json'], key, cert, store)
+                    dis.assert_called_with(['json'])
+                    reg.assert_called_with('A', 'B', store, 'sha1', 'json')
+                    registry._set_default_serializer.assert_called_with('auth')
+        finally:
+            self.app.conf.CELERY_TASK_SERIALIZER = prev
 
     def test_security_conf(self):
-        current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
-
-        self.assertRaises(ImproperlyConfigured, setup_security)
-
-        _import = builtins.__import__
-
-        def import_hook(name, *args, **kwargs):
-            if name == 'OpenSSL':
-                raise ImportError
-            return _import(name, *args, **kwargs)
-
-        builtins.__import__ = import_hook
-        self.assertRaises(ImproperlyConfigured, setup_security)
-        builtins.__import__ = _import
+        prev, self.app.conf.CELERY_TASK_SERIALIZER = (
+            self.app.conf.CELERY_TASK_SERIALIZER, 'auth')
+        try:
+            with self.assertRaises(ImproperlyConfigured):
+                self.app.setup_security()
+
+            _import = builtins.__import__
+
+            def import_hook(name, *args, **kwargs):
+                if name == 'OpenSSL':
+                    raise ImportError
+                return _import(name, *args, **kwargs)
+
+            builtins.__import__ = import_hook
+            with self.assertRaises(ImproperlyConfigured):
+                self.app.setup_security()
+            builtins.__import__ = _import
+        finally:
+            self.app.conf.CELERY_TASK_SERIALIZER = prev
 
     def test_reraise_errors(self):
         with self.assertRaises(SecurityError):

+ 27 - 28
celery/tests/tasks/test_chord.py

@@ -3,26 +3,25 @@ from __future__ import absolute_import
 from mock import patch
 from contextlib import contextmanager
 
+from celery import group
 from celery import canvas
-from celery import current_app
 from celery import result
 from celery.exceptions import ChordError
 from celery.five import range
 from celery.result import AsyncResult, GroupResult, EagerResult
-from celery.task import task, TaskSet
 from celery.tests.case import AppCase, Mock
 
 passthru = lambda x: x
 
 
-@current_app.task
-def add(x, y):
-    return x + y
+class ChordCase(AppCase):
 
+    def setup(self):
 
-@current_app.task
-def callback(r):
-    return r
+        @self.app.task
+        def add(x, y):
+            return x + y
+        self.add = add
 
 
 class TSR(GroupResult):
@@ -53,8 +52,8 @@ class TSRNoReport(TSR):
 
 
 @contextmanager
-def patch_unlock_retry():
-    unlock = current_app.tasks['celery.chord_unlock']
+def patch_unlock_retry(app):
+    unlock = app.tasks['celery.chord_unlock']
     retry = Mock()
     prev, unlock.retry = unlock.retry, retry
     try:
@@ -63,7 +62,7 @@ def patch_unlock_retry():
         unlock.retry = prev
 
 
-class test_unlock_chord_task(AppCase):
+class test_unlock_chord_task(ChordCase):
 
     @patch('celery.result.GroupResult')
     def test_unlock_ready(self, GroupResult):
@@ -133,7 +132,7 @@ class test_unlock_chord_task(AppCase):
     def _chord_context(self, ResultCls, setup=None, **kwargs):
         with patch('celery.result.GroupResult'):
 
-            @task()
+            @self.app.task()
             def callback(*args, **kwargs):
                 pass
 
@@ -143,7 +142,7 @@ class test_unlock_chord_task(AppCase):
             callback_s.id = 'callback_id'
             fail_current = self.app.backend.fail_from_current_stack = Mock()
             try:
-                with patch_unlock_retry() as (unlock, retry):
+                with patch_unlock_retry(self.app) as (unlock, retry):
                     subtask, canvas.maybe_subtask = (
                         canvas.maybe_subtask, passthru,
                     )
@@ -173,19 +172,19 @@ class test_unlock_chord_task(AppCase):
             retry.assert_called_with(countdown=10, max_retries=30)
 
     def test_is_in_registry(self):
-        self.assertIn('celery.chord_unlock', current_app.tasks)
+        self.assertIn('celery.chord_unlock', self.app.tasks)
 
 
-class test_chord(AppCase):
+class test_chord(ChordCase):
 
     def test_eager(self):
         from celery import chord
 
-        @task()
+        @self.app.task()
         def addX(x, y):
             return x + y
 
-        @task()
+        @self.app.task()
         def sumX(n):
             return sum(n)
 
@@ -207,8 +206,8 @@ class test_chord(AppCase):
         m.AsyncResult = AsyncResult
         prev, chord._type = chord._type, m
         try:
-            x = chord(add.s(i, i) for i in range(10))
-            body = add.s(2)
+            x = chord(self.add.s(i, i) for i in range(10))
+            body = self.add.s(2)
             result = x(body)
             self.assertTrue(result.id)
             # does not modify original subtask
@@ -219,18 +218,18 @@ class test_chord(AppCase):
             chord._type = prev
 
 
-class test_Chord_task(AppCase):
+class test_Chord_task(ChordCase):
 
     def test_run(self):
-        prev, current_app.backend = current_app.backend, Mock()
-        current_app.backend.cleanup = Mock()
-        current_app.backend.cleanup.__name__ = 'cleanup'
+        prev, self.app.backend = self.app.backend, Mock()
+        self.app.backend.cleanup = Mock()
+        self.app.backend.cleanup.__name__ = 'cleanup'
         try:
-            Chord = current_app.tasks['celery.chord']
+            Chord = self.app.tasks['celery.chord']
 
             body = dict()
-            Chord(TaskSet(add.subtask((i, i)) for i in range(5)), body)
-            Chord([add.subtask((j, j)) for j in range(5)], body)
-            self.assertEqual(current_app.backend.on_chord_apply.call_count, 2)
+            Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
+            Chord([self.add.subtask((j, j)) for j in range(5)], body)
+            self.assertEqual(self.app.backend.on_chord_apply.call_count, 2)
         finally:
-            current_app.backend = prev
+            self.app.backend = prev

+ 5 - 5
celery/tests/tasks/test_http.py

@@ -13,7 +13,7 @@ from kombu.utils.encoding import from_utf8
 
 from celery.five import StringIO, items
 from celery.task import http
-from celery.tests.case import Case, eager_tasks
+from celery.tests.case import AppCase, Case, eager_tasks
 
 
 @contextmanager
@@ -96,7 +96,7 @@ class test_MutableURL(Case):
         self.assertEqual(url.query, {'zzz': 'xxx'})
 
 
-class test_HttpDispatch(Case):
+class test_HttpDispatch(AppCase):
 
     def test_dispatch_success(self):
         with mock_urlopen(success_response(100)):
@@ -139,16 +139,16 @@ class test_HttpDispatch(Case):
             self.assertEqual(d.dispatch(), 100)
 
 
-class test_URL(Case):
+class test_URL(AppCase):
 
     def test_URL_get_async(self):
-        with eager_tasks():
+        with eager_tasks(self.app):
             with mock_urlopen(success_response(100)):
                 d = http.URL('http://example.com/mul').get_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)
 
     def test_URL_post_async(self):
-        with eager_tasks():
+        with eager_tasks(self.app):
             with mock_urlopen(success_response(100)):
                 d = http.URL('http://example.com/mul').post_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)

+ 15 - 14
celery/tests/tasks/test_result.py

@@ -23,11 +23,6 @@ from celery.tests.case import AppCase
 from celery.tests.case import skip_if_quick
 
 
-@task()
-def mytask():
-    pass
-
-
 def mock_task(name, state, result):
     return dict(id=uuid(), name=name, state=state, result=result)
 
@@ -63,6 +58,11 @@ class test_AsyncResult(AppCase):
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(self.app, task)
 
+        @self.app.task()
+        def mytask():
+            pass
+        self.mytask = mytask
+
     def test_compat_properties(self):
         x = self.app.AsyncResult('1')
         self.assertEqual(x.task_id, x.id)
@@ -153,10 +153,10 @@ class test_AsyncResult(AppCase):
         self.assertFalse(self.app.AsyncResult('1') == object())
 
     def test_reduce(self):
-        a1 = self.app.AsyncResult('uuid', task_name=mytask.name)
+        a1 = self.app.AsyncResult('uuid', task_name=self.mytask.name)
         restored = pickle.loads(pickle.dumps(a1))
         self.assertEqual(restored.id, 'uuid')
-        self.assertEqual(restored.task_name, mytask.name)
+        self.assertEqual(restored.task_name, self.mytask.name)
 
         a2 = self.app.AsyncResult('uuid')
         self.assertEqual(pickle.loads(pickle.dumps(a2)).id, 'uuid')
@@ -658,16 +658,17 @@ class test_pending_Group(AppCase):
             self.ts.join(timeout=1)
 
 
-class RaisingTask(Task):
-
-    def run(self, x, y):
-        raise KeyError('xy')
+class test_EagerResult(AppCase):
 
+    def setup(self):
 
-class test_EagerResult(AppCase):
+        @self.app.task
+        def raising(x, y):
+            raise KeyError(x, y)
+        self.raising = raising
 
     def test_wait_raises(self):
-        res = RaisingTask.apply(args=[3, 3])
+        res = self.raising.apply(args=[3, 3])
         with self.assertRaises(KeyError):
             res.wait()
         self.assertTrue(res.wait(propagate=False))
@@ -683,7 +684,7 @@ class test_EagerResult(AppCase):
         res.forget()
 
     def test_revoke(self):
-        res = RaisingTask.apply(args=[3, 3])
+        res = self.raising.apply(args=[3, 3])
         self.assertFalse(res.revoke())
 
 

+ 33 - 29
celery/tests/tasks/test_tasks.py

@@ -17,8 +17,6 @@ from celery.task import (
     periodic_task,
     PeriodicTask
 )
-from celery import current_app
-from celery.app import app_or_default
 from celery.exceptions import RetryTaskError
 from celery.execute import send_task
 from celery.five import items, range, string_t
@@ -27,11 +25,7 @@ from celery.schedules import crontab, crontab_parser, ParseException
 from celery.utils import uuid
 from celery.utils.timeutils import parse_iso8601, timedelta_seconds
 
-from celery.tests.case import Case, with_eager_tasks, WhateverIO
-
-
-def now():
-    return current_app.now()
+from celery.tests.case import AppCase, eager_tasks, WhateverIO
 
 
 def return_True(*args, **kwargs):
@@ -124,7 +118,7 @@ def retry_task_customexc(arg1, arg2, kwarg=1, **kwargs):
             raise current.retry(countdown=0, exc=exc)
 
 
-class test_task_retries(Case):
+class test_task_retries(AppCase):
 
     def test_retry(self):
         retry_task.__class__.max_retries = 3
@@ -207,7 +201,7 @@ class test_task_retries(Case):
         self.assertEqual(retry_task.iterations, 2)
 
 
-class test_canvas_utils(Case):
+class test_canvas_utils(AppCase):
 
     def test_si(self):
         self.assertTrue(retry_task.si())
@@ -226,7 +220,10 @@ class test_canvas_utils(Case):
         retry_task.on_success(1, 1, (), {})
 
 
-class test_tasks(Case):
+class test_tasks(AppCase):
+
+    def now(self):
+        return self.app.now()
 
     def test_unpickle_task(self):
         import pickle
@@ -312,8 +309,8 @@ class test_tasks(Case):
         # With eta.
         presult2 = T1.apply_async(
             kwargs=dict(name='George Costanza'),
-            eta=now() + timedelta(days=1),
-            expires=now() + timedelta(days=2),
+            eta=self.now() + timedelta(days=1),
+            expires=self.now() + timedelta(days=2),
         )
         self.assertNextTaskDataEqual(
             consumer, presult2, T1.name,
@@ -411,7 +408,7 @@ class test_tasks(Case):
                 del(app.amqp.__dict__['TaskProducer'])
 
     def test_get_publisher(self):
-        connection = app_or_default().connection()
+        connection = self.app.connection()
         p = increment_counter.get_publisher(connection, auto_declare=False,
                                             exchange='foo')
         self.assertEqual(p.exchange.name, 'foo')
@@ -471,14 +468,14 @@ class test_tasks(Case):
             t1.pop_request()
 
 
-class test_TaskSet(Case):
+class test_TaskSet(AppCase):
 
-    @with_eager_tasks
     def test_function_taskset(self):
-        subtasks = [return_True_task.s(i) for i in range(1, 6)]
-        ts = TaskSet(subtasks)
-        res = ts.apply_async()
-        self.assertListEqual(res.join(), [True, True, True, True, True])
+        with eager_tasks(self.app):
+            subtasks = [return_True_task.s(i) for i in range(1, 6)]
+            ts = TaskSet(subtasks)
+            res = ts.apply_async()
+            self.assertListEqual(res.join(), [True, True, True, True, True])
 
     def test_counter_taskset(self):
         increment_counter.count = 0
@@ -518,7 +515,7 @@ class test_TaskSet(Case):
         self.assertTrue(res.taskset_id.startswith(prefix))
 
 
-class test_apply_task(Case):
+class test_apply_task(AppCase):
 
     def test_apply_throw(self):
         with self.assertRaises(KeyError):
@@ -569,7 +566,10 @@ def my_periodic():
     pass
 
 
-class test_periodic_tasks(Case):
+class test_periodic_tasks(AppCase):
+
+    def now(self):
+        return self.app.now()
 
     def test_must_have_run_every(self):
         with self.assertRaises(NotImplementedError):
@@ -578,11 +578,11 @@ class test_periodic_tasks(Case):
     def test_remaining_estimate(self):
         s = my_periodic.run_every
         self.assertIsInstance(
-            s.remaining_estimate(s.maybe_make_aware(now())),
+            s.remaining_estimate(s.maybe_make_aware(self.now())),
             timedelta)
 
     def test_is_due_not_due(self):
-        due, remaining = my_periodic.run_every.is_due(now())
+        due, remaining = my_periodic.run_every.is_due(self.now())
         self.assertFalse(due)
         # This assertion may fail if executed in the
         # first minute of an hour, thus 59 instead of 60
@@ -591,7 +591,8 @@ class test_periodic_tasks(Case):
     def test_is_due(self):
         p = my_periodic
         due, remaining = p.run_every.is_due(
-            now() - p.run_every.run_every)
+            self.now() - p.run_every.run_every,
+        )
         self.assertTrue(due)
         self.assertEqual(remaining,
                          timedelta_seconds(p.run_every.run_every))
@@ -660,7 +661,7 @@ def patch_crontab_nowfun(cls, retval):
     return create_patcher
 
 
-class test_crontab_parser(Case):
+class test_crontab_parser(AppCase):
 
     def test_crontab_reduce(self):
         self.assertTrue(loads(dumps(crontab('*'))))
@@ -810,7 +811,7 @@ class test_crontab_parser(Case):
         self.assertFalse(crontab(minute='1') == object())
 
 
-class test_crontab_remaining_estimate(Case):
+class test_crontab_remaining_estimate(AppCase):
 
     def next_ocurrance(self, crontab, now):
         crontab.nowfun = lambda: now
@@ -982,10 +983,13 @@ class test_crontab_remaining_estimate(Case):
         self.assertEqual(next, datetime(2010, 5, 29, 0, 5))
 
 
-class test_crontab_is_due(Case):
+class test_crontab_is_due(AppCase):
+
+    def getnow(self):
+        return self.app.now()
 
-    def setUp(self):
-        self.now = now()
+    def setup(self):
+        self.now = self.getnow()
         self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
 
     def test_default_crontab_spec(self):

+ 4 - 3
celery/tests/worker/test_control.py

@@ -249,7 +249,7 @@ class test_ControlPanel(AppCase):
         self.panel.handle('report')
 
     def test_active(self):
-        r = TaskRequest(mytask.name, 'do re mi', (), {})
+        r = TaskRequest(mytask.name, 'do re mi', (), {}, app=self.app)
         worker_state.active_requests.add(r)
         try:
             self.assertTrue(self.panel.handle('dump_active'))
@@ -331,7 +331,7 @@ class test_ControlPanel(AppCase):
         consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         self.assertFalse(panel.handle('dump_schedule'))
-        r = TaskRequest(mytask.name, 'CAFEBABE', (), {})
+        r = TaskRequest(mytask.name, 'CAFEBABE', (), {}, app=self.app)
         consumer.timer.schedule.enter(
             consumer.timer.Entry(lambda x: x, (r, )),
             datetime.now() + timedelta(seconds=10))
@@ -343,7 +343,8 @@ class test_ControlPanel(AppCase):
     def test_dump_reserved(self):
         consumer = Consumer(self.app)
         worker_state.reserved_requests.add(
-            TaskRequest(mytask.name, uuid(), args=(2, 2), kwargs={}),
+            TaskRequest(mytask.name, uuid(), args=(2, 2), kwargs={},
+                        app=self.app),
         )
         try:
             panel = self.create_panel(consumer=consumer)

+ 25 - 25
celery/tests/worker/test_loops.py

@@ -15,7 +15,7 @@ from celery.tests.case import AppCase, body_from_sig
 
 class X(object):
 
-    def __init__(self, heartbeat=None, on_task=None):
+    def __init__(self, app, heartbeat=None, on_task=None):
         (
             self.obj,
             self.connection,
@@ -46,7 +46,7 @@ class X(object):
         self.hub.fire_timers.return_value = 1.7
         self.Hub = self.hub
         # need this for create_task_handler
-        _consumer = Consumer(Mock(), timer=Mock())
+        _consumer = Consumer(Mock(), timer=Mock(), app=app)
         self.obj.create_task_handler = _consumer.create_task_handler
         self.on_unknown_message = self.obj.on_unknown_message = Mock(
             name='on_unknown_message',
@@ -105,7 +105,7 @@ class test_asynloop(AppCase):
         self.add = add
 
     def test_setup_heartbeat(self):
-        x = X(heartbeat=10)
+        x = X(self.app, heartbeat=10)
         x.blueprint.state = CLOSE
         asynloop(*x.args)
         x.consumer.consume.assert_called_with()
@@ -115,7 +115,7 @@ class test_asynloop(AppCase):
         )
 
     def task_context(self, sig, **kwargs):
-        x, on_task = get_task_callback(**kwargs)
+        x, on_task = get_task_callback(self.app, **kwargs)
         body = body_from_sig(self.app, sig)
         message = Mock()
         strategy = x.obj.strategies[sig.task] = Mock()
@@ -153,7 +153,7 @@ class test_asynloop(AppCase):
         x.on_invalid_task.assert_called_with(body, msg, exc)
 
     def test_should_terminate(self):
-        x = X()
+        x = X(self.app)
         # XXX why aren't the errors propagated?!?
         state.should_terminate = True
         try:
@@ -163,7 +163,7 @@ class test_asynloop(AppCase):
             state.should_terminate = False
 
     def test_should_terminate_hub_close_raises(self):
-        x = X()
+        x = X(self.app)
         # XXX why aren't the errors propagated?!?
         state.should_terminate = True
         x.hub.close.side_effect = MemoryError()
@@ -174,7 +174,7 @@ class test_asynloop(AppCase):
             state.should_terminate = False
 
     def test_should_stop(self):
-        x = X()
+        x = X(self.app)
         state.should_stop = True
         try:
             with self.assertRaises(SystemExit):
@@ -183,13 +183,13 @@ class test_asynloop(AppCase):
             state.should_stop = False
 
     def test_updates_qos(self):
-        x = X()
+        x = X(self.app)
         x.qos.prev = 3
         x.qos.value = 3
         asynloop(*x.args, sleep=x.closer())
         self.assertFalse(x.qos.update.called)
 
-        x = X()
+        x = X(self.app)
         x.qos.prev = 1
         x.qos.value = 6
         asynloop(*x.args, sleep=x.closer())
@@ -198,7 +198,7 @@ class test_asynloop(AppCase):
         x.connection.transport.on_poll_start.assert_called_with()
 
     def test_poll_empty(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.fire_timers.return_value = 33.37
@@ -209,7 +209,7 @@ class test_asynloop(AppCase):
         x.connection.transport.on_poll_empty.assert_called_with()
 
     def test_poll_readable(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait, mod=4)
         x.hub.poller.poll.return_value = [(6, READ)]
@@ -219,7 +219,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_readable_raises_Empty(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, READ)]
@@ -230,7 +230,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_writable(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, WRITE)]
@@ -240,7 +240,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_writable_none_registered(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(7, WRITE)]
@@ -249,7 +249,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_unknown_event(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, 0)]
@@ -258,7 +258,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_keep_draining_disabled(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         poll = x.hub.poller.poll
 
@@ -275,7 +275,7 @@ class test_asynloop(AppCase):
         self.assertFalse(x.connection.drain_nowait.called)
 
     def test_poll_err_writable(self):
-        x = X()
+        x = X(self.app)
         x.hub.writers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, ERR)]
@@ -285,7 +285,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_write_generator(self):
-        x = X()
+        x = X(self.app)
 
         def Gen():
             yield 1
@@ -301,7 +301,7 @@ class test_asynloop(AppCase):
         self.assertFalse(x.hub.remove.called)
 
     def test_poll_write_generator_stopped(self):
-        x = X()
+        x = X(self.app)
 
         def Gen():
             raise StopIteration()
@@ -316,7 +316,7 @@ class test_asynloop(AppCase):
         x.hub.remove.assert_called_with(6)
 
     def test_poll_write_generator_raises(self):
-        x = X()
+        x = X(self.app)
 
         def Gen():
             raise ValueError('foo')
@@ -331,7 +331,7 @@ class test_asynloop(AppCase):
         x.hub.remove.assert_called_with(6)
 
     def test_poll_err_readable(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, ERR)]
@@ -341,7 +341,7 @@ class test_asynloop(AppCase):
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_raises_ValueError(self):
-        x = X()
+        x = X(self.app)
         x.hub.readers = {6: Mock()}
         x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.side_effect = ValueError()
@@ -352,14 +352,14 @@ class test_asynloop(AppCase):
 class test_synloop(AppCase):
 
     def test_timeout_ignored(self):
-        x = X()
+        x = X(self.app)
         x.timeout_then_error(x.connection.drain_events)
         with self.assertRaises(socket.error):
             synloop(*x.args)
         self.assertEqual(x.connection.drain_events.call_count, 2)
 
     def test_updates_qos_when_changed(self):
-        x = X()
+        x = X(self.app)
         x.qos.prev = 2
         x.qos.value = 2
         x.timeout_then_error(x.connection.drain_events)
@@ -374,6 +374,6 @@ class test_synloop(AppCase):
         x.qos.update.assert_called_with()
 
     def test_ignores_socket_errors_when_closed(self):
-        x = X()
+        x = X(self.app)
         x.close_then_error(x.connection.drain_events)
         self.assertIsNone(synloop(*x.args))

+ 54 - 46
celery/tests/worker/test_request.py

@@ -203,7 +203,7 @@ class test_trace_task(AppCase):
                 if state == states.STARTED:
                     self._started.append(tid)
 
-        prev, mytask.backend = mytask.backend, Backend()
+        prev, mytask.backend = mytask.backend, Backend(self.app)
         mytask.track_started = True
 
         try:
@@ -350,7 +350,7 @@ class test_Request(AppCase):
             self.add.accept_magic_kwargs = False
 
     def test_task_wrapper_repr(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         self.assertTrue(repr(tw))
 
     @patch('celery.worker.job.kwdict')
@@ -358,7 +358,7 @@ class test_Request(AppCase):
 
         prev, module.NEEDS_KWDICT = module.NEEDS_KWDICT, True
         try:
-            TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
             self.assertTrue(kwdict.called)
         finally:
             module.NEEDS_KWDICT = prev
@@ -366,23 +366,25 @@ class test_Request(AppCase):
     def test_sets_store_errors(self):
         mytask.ignore_result = True
         try:
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             self.assertFalse(tw.store_errors)
             mytask.store_errors_even_if_ignored = True
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             self.assertTrue(tw.store_errors)
         finally:
             mytask.ignore_result = False
             mytask.store_errors_even_if_ignored = False
 
     def test_send_event(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.eventer = MockEventDispatcher()
         tw.send_event('task-frobulated')
         self.assertIn('task-frobulated', tw.eventer.sent)
 
     def test_on_retry(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.eventer = MockEventDispatcher()
         try:
             raise RetryTaskError('foo', KeyError('moofoobar'))
@@ -399,7 +401,7 @@ class test_Request(AppCase):
             tw.on_failure(einfo)
 
     def test_compat_properties(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         self.assertEqual(tw.task_id, tw.id)
         self.assertEqual(tw.task_name, tw.name)
         tw.task_id = 'ID'
@@ -410,7 +412,7 @@ class test_Request(AppCase):
     def test_terminate__task_started(self):
         pool = Mock()
         signum = signal.SIGKILL
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=True,
                                   expired=False,
@@ -422,7 +424,7 @@ class test_Request(AppCase):
 
     def test_terminate__task_reserved(self):
         pool = Mock()
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = None
         tw.terminate(pool, signal='KILL')
         self.assertFalse(pool.terminate_job.called)
@@ -431,7 +433,8 @@ class test_Request(AppCase):
 
     def test_revoked_expires_expired(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
-                         expires=datetime.utcnow() - timedelta(days=1))
+                         expires=datetime.utcnow() - timedelta(days=1),
+                         app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=False,
                                   expired=True,
@@ -443,7 +446,8 @@ class test_Request(AppCase):
 
     def test_revoked_expires_not_expired(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
-                         expires=datetime.utcnow() + timedelta(days=1))
+                         expires=datetime.utcnow() + timedelta(days=1),
+                         app=self.app)
         tw.revoked()
         self.assertNotIn(tw.id, revoked)
         self.assertNotEqual(
@@ -454,7 +458,8 @@ class test_Request(AppCase):
     def test_revoked_expires_ignore_result(self):
         mytask.ignore_result = True
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
-                         expires=datetime.utcnow() - timedelta(days=1))
+                         expires=datetime.utcnow() - timedelta(days=1),
+                         app=self.app)
         try:
             tw.revoked()
             self.assertIn(tw.id, revoked)
@@ -482,7 +487,8 @@ class test_Request(AppCase):
         app.mail_admins = mock_mail_admins
         mytask.send_error_emails = True
         try:
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
 
             einfo = get_ei()
             tw.on_failure(einfo)
@@ -505,12 +511,12 @@ class test_Request(AppCase):
             mytask.send_error_emails = old_enable_mails
 
     def test_already_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw._already_revoked = True
         self.assertTrue(tw.revoked())
 
     def test_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=False,
                                   expired=False,
@@ -521,13 +527,13 @@ class test_Request(AppCase):
             self.assertTrue(tw.acknowledged)
 
     def test_execute_does_not_execute_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         revoked.add(tw.id)
         tw.execute()
 
     def test_execute_acks_late(self):
         mytask_raising.acks_late = True
-        tw = TaskRequest(mytask_raising.name, uuid(), [1])
+        tw = TaskRequest(mytask_raising.name, uuid(), [1], app=self.app)
         try:
             tw.execute()
             self.assertTrue(tw.acknowledged)
@@ -537,13 +543,13 @@ class test_Request(AppCase):
             mytask_raising.acks_late = False
 
     def test_execute_using_pool_does_not_execute_revoked(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         revoked.add(tw.id)
         with self.assertRaises(TaskRevokedError):
             tw.execute_using_pool(None)
 
     def test_on_accepted_acks_early(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
         self.assertTrue(tw.acknowledged)
         prev, module._does_debug = module._does_debug, False
@@ -553,7 +559,7 @@ class test_Request(AppCase):
             module._does_debug = prev
 
     def test_on_accepted_acks_late(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         mytask.acks_late = True
         try:
             tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
@@ -564,7 +570,7 @@ class test_Request(AppCase):
     def test_on_accepted_terminates(self):
         signum = signal.SIGKILL
         pool = Mock()
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         with assert_signal_called(task_revoked, sender=tw.task,
                                   terminated=True,
                                   expired=False,
@@ -575,7 +581,7 @@ class test_Request(AppCase):
             pool.terminate_job.assert_called_with(314, signum)
 
     def test_on_success_acks_early(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.on_success(42)
         prev, module._does_info = module._does_info, False
@@ -586,7 +592,7 @@ class test_Request(AppCase):
             module._does_info = prev
 
     def test_on_success_BaseException(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         with self.assertRaises(SystemExit):
             try:
@@ -597,7 +603,7 @@ class test_Request(AppCase):
                 assert False
 
     def test_on_success_eventer(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.eventer = Mock()
         tw.send_event = Mock()
@@ -605,7 +611,7 @@ class test_Request(AppCase):
         self.assertTrue(tw.send_event.called)
 
     def test_on_success_when_failure(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         tw.on_failure = Mock()
         try:
@@ -615,7 +621,7 @@ class test_Request(AppCase):
             self.assertTrue(tw.on_failure.called)
 
     def test_on_success_acks_late(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         mytask.acks_late = True
         try:
@@ -632,7 +638,7 @@ class test_Request(AppCase):
             except WorkerLostError:
                 return ExceptionInfo()
 
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         exc_info = get_ei()
         tw.on_failure(exc_info)
         self.assertEqual(mytask.backend.get_status(tw.id),
@@ -641,7 +647,8 @@ class test_Request(AppCase):
         mytask.ignore_result = True
         try:
             exc_info = get_ei()
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             tw.on_failure(exc_info)
             self.assertEqual(mytask.backend.get_status(tw.id),
                              states.PENDING)
@@ -649,7 +656,7 @@ class test_Request(AppCase):
             mytask.ignore_result = False
 
     def test_on_failure_acks_late(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.time_start = 1
         mytask.acks_late = True
         try:
@@ -665,13 +672,13 @@ class test_Request(AppCase):
     def test_from_message_invalid_kwargs(self):
         body = dict(task=mytask.name, id=1, args=(), kwargs='foo')
         with self.assertRaises(InvalidTaskError):
-            TaskRequest.from_message(None, body)
+            TaskRequest.from_message(None, body, app=self.app)
 
     @patch('celery.worker.job.error')
     @patch('celery.worker.job.warn')
     def test_on_timeout(self, warn, error):
 
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'}, app=self.app)
         tw.on_timeout(soft=True, timeout=1337)
         self.assertIn('Soft time limit', warn.call_args[0][0])
         tw.on_timeout(soft=False, timeout=1337)
@@ -681,7 +688,8 @@ class test_Request(AppCase):
 
         mytask.ignore_result = True
         try:
-            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+            tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
+                             app=self.app)
             tw.on_timeout(soft=True, timeout=1336)
             self.assertEqual(mytask.backend.get_status(tw.id),
                              states.PENDING)
@@ -771,7 +779,7 @@ class test_Request(AppCase):
             mytask.pop_request()
 
     def test_task_wrapper_mail_attrs(self):
-        tw = TaskRequest(mytask.name, uuid(), [], {})
+        tw = TaskRequest(mytask.name, uuid(), [], {}, app=self.app)
         x = tw.success_msg % {
             'name': tw.name,
             'id': tw.id,
@@ -794,7 +802,7 @@ class test_Request(AppCase):
         m = Message(None, body=anyjson.dumps(body), backend='foo',
                     content_type='application/json',
                     content_encoding='utf-8')
-        tw = TaskRequest.from_message(m, m.decode())
+        tw = TaskRequest.from_message(m, m.decode(), app=self.app)
         self.assertIsInstance(tw, Request)
         self.assertEqual(tw.name, body['task'])
         self.assertEqual(tw.id, body['id'])
@@ -809,7 +817,7 @@ class test_Request(AppCase):
         m = Message(None, body=anyjson.dumps(body), backend='foo',
                     content_type='application/json',
                     content_encoding='utf-8')
-        tw = TaskRequest.from_message(m, m.decode())
+        tw = TaskRequest.from_message(m, m.decode(), app=self.app)
         self.assertIsInstance(tw, Request)
         self.assertEquals(tw.args, [])
         self.assertEquals(tw.kwargs, {})
@@ -820,7 +828,7 @@ class test_Request(AppCase):
                     content_type='application/json',
                     content_encoding='utf-8')
         with self.assertRaises(KeyError):
-            TaskRequest.from_message(m, m.decode())
+            TaskRequest.from_message(m, m.decode(), app=self.app)
 
     def test_from_message_nonexistant_task(self):
         body = {'task': 'cu.mytask.doesnotexist', 'id': uuid(),
@@ -829,11 +837,11 @@ class test_Request(AppCase):
                     content_type='application/json',
                     content_encoding='utf-8')
         with self.assertRaises(KeyError):
-            TaskRequest.from_message(m, m.decode())
+            TaskRequest.from_message(m, m.decode(), app=self.app)
 
     def test_execute(self):
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
         self.assertEqual(tw.execute(), 256)
         meta = mytask.backend.get_task_meta(tid)
         self.assertEqual(meta['result'], 256)
@@ -841,7 +849,7 @@ class test_Request(AppCase):
 
     def test_execute_success_no_kwargs(self):
         tid = uuid()
-        tw = TaskRequest(mytask_no_kwargs.name, tid, [4], {})
+        tw = TaskRequest(mytask_no_kwargs.name, tid, [4], {}, app=self.app)
         self.assertEqual(tw.execute(), 256)
         meta = mytask_no_kwargs.backend.get_task_meta(tid)
         self.assertEqual(meta['result'], 256)
@@ -849,7 +857,7 @@ class test_Request(AppCase):
 
     def test_execute_success_some_kwargs(self):
         tid = uuid()
-        tw = TaskRequest(mytask_some_kwargs.name, tid, [4], {})
+        tw = TaskRequest(mytask_some_kwargs.name, tid, [4], {}, app=self.app)
         self.assertEqual(tw.execute(), 256)
         meta = mytask_some_kwargs.backend.get_task_meta(tid)
         self.assertEqual(some_kwargs_scratchpad.get('task_id'), tid)
@@ -859,7 +867,7 @@ class test_Request(AppCase):
     def test_execute_ack(self):
         tid = uuid()
         tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'},
-                         on_ack=on_ack)
+                         on_ack=on_ack, app=self.app)
         self.assertEqual(tw.execute(), 256)
         meta = mytask.backend.get_task_meta(tid)
         self.assertTrue(scratch['ACK'])
@@ -868,7 +876,7 @@ class test_Request(AppCase):
 
     def test_execute_fail(self):
         tid = uuid()
-        tw = TaskRequest(mytask_raising.name, tid, [4])
+        tw = TaskRequest(mytask_raising.name, tid, [4], app=self.app)
         self.assertIsInstance(tw.execute(), ExceptionInfo)
         meta = mytask_raising.backend.get_task_meta(tid)
         self.assertEqual(meta['status'], states.FAILURE)
@@ -876,7 +884,7 @@ class test_Request(AppCase):
 
     def test_execute_using_pool(self):
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
 
         class MockPool(BasePool):
             target = None
@@ -906,7 +914,7 @@ class test_Request(AppCase):
 
     def test_default_kwargs(self):
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
         self.assertDictEqual(
             tw.extend_with_default_kwargs(), {
                 'f': 'x',
@@ -926,7 +934,7 @@ class test_Request(AppCase):
     def _test_on_failure(self, exception, logger):
         app = self.app
         tid = uuid()
-        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
+        tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'}, app=self.app)
         try:
             raise exception
         except Exception:

+ 41 - 40
celery/tests/worker/test_worker.py

@@ -23,7 +23,6 @@ from celery.five import Empty, range, Queue as FastQueue
 from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
-from celery.worker import WorkController
 from celery.worker import components
 from celery.worker import consumer
 from celery.worker.consumer import Consumer as __Consumer
@@ -243,7 +242,7 @@ class test_Consumer(AppCase):
         self.timer.stop()
 
     def test_info(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer.qos, 10)
         l.connection = Mock()
@@ -257,12 +256,12 @@ class test_Consumer(AppCase):
         self.assertTrue(info['broker'])
 
     def test_start_when_closed(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = CLOSE
         l.start()
 
     def test_connection(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
 
         l.blueprint.start(l)
         self.assertIsInstance(l.connection, Connection)
@@ -287,7 +286,7 @@ class test_Consumer(AppCase):
         self.assertIsNone(l.task_consumer)
 
     def test_close_connection(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         step = find_step(l, consumer.Connection)
         conn = l.connection = Mock()
@@ -295,7 +294,7 @@ class test_Consumer(AppCase):
         self.assertTrue(conn.close.called)
         self.assertIsNone(l.connection)
 
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         eventer = l.event_dispatcher = Mock()
         eventer.enabled = True
         heart = l.heart = MockHeart()
@@ -309,7 +308,7 @@ class test_Consumer(AppCase):
 
     @patch('celery.worker.consumer.warn')
     def test_receive_message_unknown(self, warn):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
@@ -323,7 +322,7 @@ class test_Consumer(AppCase):
     @patch('celery.worker.strategy.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                            args=('2, 2'),
@@ -340,7 +339,7 @@ class test_Consumer(AppCase):
 
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
@@ -354,7 +353,7 @@ class test_Consumer(AppCase):
 
     @patch('celery.worker.consumer.crit')
     def test_on_decode_error(self, crit):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
 
         class MockMessage(Mock):
             content_type = 'application/x-msgpack'
@@ -380,7 +379,7 @@ class test_Consumer(AppCase):
         return l.task_consumer.register_callback.call_args[0][0]
 
     def test_receieve_message(self):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.event_dispatcher = Mock()
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
@@ -406,7 +405,7 @@ class test_Consumer(AppCase):
                 raise SyntaxError('bar')
 
         l = MockConsumer(self.buffer.put, timer=self.timer,
-                         send_events=False, pool=BasePool())
+                         send_events=False, pool=BasePool(), app=self.app)
         l.channel_errors = (KeyError, )
         with self.assertRaises(KeyError):
             l.start()
@@ -424,7 +423,7 @@ class test_Consumer(AppCase):
                 raise SyntaxError('bar')
 
         l = MockConsumer(self.buffer.put, timer=self.timer,
-                         send_events=False, pool=BasePool())
+                         send_events=False, pool=BasePool(), app=self.app)
 
         l.connection_errors = (KeyError, )
         self.assertRaises(SyntaxError, l.start)
@@ -439,7 +438,7 @@ class test_Consumer(AppCase):
                 self.obj.connection = None
                 raise socket.timeout(10)
 
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.connection = Connection()
         l.task_consumer = Mock()
         l.connection.obj = l
@@ -455,7 +454,7 @@ class test_Consumer(AppCase):
                 self.obj.connection = None
                 raise socket.error('foo')
 
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         c = l.connection = Connection()
         l.connection.obj = l
@@ -476,7 +475,7 @@ class test_Consumer(AppCase):
             def drain_events(self, **kwargs):
                 self.obj.connection = None
 
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
@@ -494,7 +493,7 @@ class test_Consumer(AppCase):
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
     def test_ignore_errors(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.connection_errors = (AttributeError, KeyError, )
         l.channel_errors = (SyntaxError, )
         ignore_errors(l, Mock(side_effect=AttributeError('foo')))
@@ -505,7 +504,7 @@ class test_Consumer(AppCase):
 
     def test_apply_eta_task(self):
         from celery.worker import state
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.qos = QoS(None, 10)
 
         task = object()
@@ -516,7 +515,7 @@ class test_Consumer(AppCase):
         self.assertIs(self.buffer.get_nowait(), task)
 
     def test_receieve_message_eta_isoformat(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         m = create_message(
             Mock(), task=foo_task.name,
@@ -545,7 +544,7 @@ class test_Consumer(AppCase):
         l.timer.stop()
 
     def test_pidbox_callback(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         con = find_step(l, consumer.Control).box
         con.node = Mock()
         con.reset = Mock()
@@ -565,7 +564,7 @@ class test_Consumer(AppCase):
         self.assertTrue(con.reset.called)
 
     def test_revoke(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         backend = Mock()
         id = uuid()
@@ -579,7 +578,7 @@ class test_Consumer(AppCase):
         self.assertTrue(self.buffer.empty())
 
     def test_receieve_message_not_registered(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
@@ -594,7 +593,7 @@ class test_Consumer(AppCase):
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
@@ -612,7 +611,7 @@ class test_Consumer(AppCase):
         self.assertTrue(logger.critical.call_count)
 
     def test_receive_message_eta(self):
-        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         l.event_dispatcher = Mock()
         l.event_dispatcher._outbound_buffer = deque()
@@ -646,7 +645,7 @@ class test_Consumer(AppCase):
             self.buffer.get_nowait()
 
     def test_reset_pidbox_node(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         con = find_step(l, consumer.Control).box
         con.node = Mock()
         chan = con.node.channel = Mock()
@@ -660,7 +659,8 @@ class test_Consumer(AppCase):
         from celery.worker.pidbox import gPidbox
         pool = Mock()
         pool.is_green = True
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
+                            app=self.app)
         con = find_step(l, consumer.Control)
         self.assertIsInstance(con.box, gPidbox)
         con.start(l)
@@ -671,7 +671,8 @@ class test_Consumer(AppCase):
     def test__green_pidbox_node(self):
         pool = Mock()
         pool.is_green = True
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
+                            app=self.app)
         l.node = Mock()
         controller = find_step(l, consumer.Control)
 
@@ -733,7 +734,7 @@ class test_Consumer(AppCase):
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.utils.sleep')
     def test_connect_errback(self, sleep, connect):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
 
@@ -746,7 +747,7 @@ class test_Consumer(AppCase):
         connect.assert_called_with()
 
     def test_stop_pidbox_node(self):
-        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         cont = find_step(l, consumer.Control)
         cont._node_stopped = Event()
         cont._node_shutdown = Event()
@@ -771,7 +772,7 @@ class test_Consumer(AppCase):
 
         init_callback = Mock()
         l = _Consumer(self.buffer.put, timer=self.timer,
-                      init_callback=init_callback)
+                      init_callback=init_callback, app=self.app)
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.qos = _QoS()
@@ -792,7 +793,7 @@ class test_Consumer(AppCase):
         self.assertEqual(l.qos.prev, l.qos.value)
 
         init_callback.reset_mock()
-        l = _Consumer(self.buffer.put, timer=self.timer,
+        l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
                       send_events=False, init_callback=init_callback)
         l.qos = _QoS()
         l.task_consumer = Mock()
@@ -804,7 +805,7 @@ class test_Consumer(AppCase):
         self.assertTrue(l.loop.call_count)
 
     def test_reset_connection_with_no_node(self):
-        l = Consumer(self.buffer.put, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
         self.assertEqual(None, l.pool)
         l.blueprint.start(l)
@@ -921,7 +922,7 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.steps)
 
     def test_with_embedded_beat(self):
-        worker = WorkController(concurrency=1, loglevel=0, beat=True)
+        worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
         self.assertTrue(worker.beat)
         self.assertIn(worker.beat, [w.obj for w in worker.steps])
 
@@ -933,7 +934,7 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.autoscaler)
 
     def test_dont_stop_or_terminate(self):
-        worker = WorkController(concurrency=1, loglevel=0)
+        worker = self.app.WorkController(concurrency=1, loglevel=0)
         worker.stop()
         self.assertNotEqual(worker.blueprint.state, CLOSE)
         worker.terminate()
@@ -950,7 +951,7 @@ class test_WorkController(AppCase):
             worker.pool.signal_safe = sigsafe
 
     def test_on_timer_error(self):
-        worker = WorkController(concurrency=1, loglevel=0)
+        worker = self.app.WorkController(concurrency=1, loglevel=0)
 
         try:
             raise KeyError('foo')
@@ -960,7 +961,7 @@ class test_WorkController(AppCase):
             self.assertIn('KeyError', msg % args)
 
     def test_on_timer_tick(self):
-        worker = WorkController(concurrency=1, loglevel=10)
+        worker = self.app.WorkController(concurrency=1, loglevel=10)
 
         components.Timer(worker).on_timer_tick(30.0)
         xargs = self.comp_logger.debug.call_args[0]
@@ -974,7 +975,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)
         self.assertEqual(worker.pool.apply_async.call_count, 1)
         worker.pool.stop()
@@ -986,7 +987,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
         worker.blueprint.state = RUN
         with self.assertRaises(KeyboardInterrupt):
@@ -1000,7 +1001,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker.steps = []
         worker.blueprint.state = RUN
         with self.assertRaises(SystemExit):
@@ -1014,7 +1015,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode(), app=self.app)
         worker._process_task(task)
         worker.pool.stop()