Просмотр исходного кода

100% coverage for celery.worker.pidbox

Ask Solem 12 лет назад
Родитель
Сommit
14f76ec2a5

+ 44 - 41
celery/tests/app/test_beat.py

@@ -7,15 +7,13 @@ from mock import Mock, call, patch
 from nose import SkipTest
 from nose import SkipTest
 from pickle import dumps, loads
 from pickle import dumps, loads
 
 
-from celery import current_app
 from celery import beat
 from celery import beat
 from celery import task
 from celery import task
 from celery.five import keys, string_t
 from celery.five import keys, string_t
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.schedules import schedule
-from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.tests.utils import Case, patch_settings
+from celery.tests.utils import AppCase, patch_settings
 
 
 
 
 class Object(object):
 class Object(object):
@@ -47,7 +45,7 @@ class MockService(object):
         self.stopped = True
         self.stopped = True
 
 
 
 
-class test_ScheduleEntry(Case):
+class test_ScheduleEntry(AppCase):
     Entry = beat.ScheduleEntry
     Entry = beat.ScheduleEntry
 
 
     def create_entry(self, **kwargs):
     def create_entry(self, **kwargs):
@@ -113,7 +111,7 @@ class mScheduler(beat.Scheduler):
                           'args': args,
                           'args': args,
                           'kwargs': kwargs,
                           'kwargs': kwargs,
                           'options': options})
                           'options': options})
-        return AsyncResult(uuid())
+        return AsyncResult(uuid(), app=self.app)
 
 
 
 
 class mSchedulerSchedulingError(mScheduler):
 class mSchedulerSchedulingError(mScheduler):
@@ -144,17 +142,17 @@ always_due = mocked_schedule(True, 1)
 always_pending = mocked_schedule(False, 1)
 always_pending = mocked_schedule(False, 1)
 
 
 
 
-class test_Scheduler(Case):
+class test_Scheduler(AppCase):
 
 
     def test_custom_schedule_dict(self):
     def test_custom_schedule_dict(self):
         custom = {'foo': 'bar'}
         custom = {'foo': 'bar'}
-        scheduler = mScheduler(schedule=custom, lazy=True)
+        scheduler = mScheduler(app=self.app, schedule=custom, lazy=True)
         self.assertIs(scheduler.data, custom)
         self.assertIs(scheduler.data, custom)
 
 
     def test_apply_async_uses_registered_task_instances(self):
     def test_apply_async_uses_registered_task_instances(self):
         through_task = [False]
         through_task = [False]
 
 
-        class MockTask(Task):
+        class MockTask(self.app.Task):
 
 
             @classmethod
             @classmethod
             def apply_async(cls, *args, **kwargs):
             def apply_async(cls, *args, **kwargs):
@@ -162,7 +160,7 @@ class test_Scheduler(Case):
 
 
         assert MockTask.name in MockTask._get_app().tasks
         assert MockTask.name in MockTask._get_app().tasks
 
 
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))
         self.assertTrue(through_task[0])
         self.assertTrue(through_task[0])
 
 
@@ -173,7 +171,7 @@ class test_Scheduler(Case):
             pass
             pass
         not_sync.apply_async = Mock()
         not_sync.apply_async = Mock()
 
 
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         s._do_sync = Mock()
         s._do_sync = Mock()
         s.should_sync = Mock()
         s.should_sync = Mock()
         s.should_sync.return_value = True
         s.should_sync.return_value = True
@@ -187,16 +185,16 @@ class test_Scheduler(Case):
 
 
     @patch('celery.app.base.Celery.send_task')
     @patch('celery.app.base.Celery.send_task')
     def test_send_task(self, send_task):
     def test_send_task(self, send_task):
-        b = beat.Scheduler()
+        b = beat.Scheduler(app=self.app)
         b.send_task('tasks.add', countdown=10)
         b.send_task('tasks.add', countdown=10)
         send_task.assert_called_with('tasks.add', countdown=10)
         send_task.assert_called_with('tasks.add', countdown=10)
 
 
     def test_info(self):
     def test_info(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         self.assertIsInstance(scheduler.info, string_t)
         self.assertIsInstance(scheduler.info, string_t)
 
 
     def test_maybe_entry(self):
     def test_maybe_entry(self):
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         entry = s.Entry(name='add every', task='tasks.add')
         entry = s.Entry(name='add every', task='tasks.add')
         self.assertIs(s._maybe_entry(entry.name, entry), entry)
         self.assertIs(s._maybe_entry(entry.name, entry), entry)
         self.assertTrue(s._maybe_entry('add every', {
         self.assertTrue(s._maybe_entry('add every', {
@@ -204,13 +202,13 @@ class test_Scheduler(Case):
         }))
         }))
 
 
     def test_set_schedule(self):
     def test_set_schedule(self):
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         s.schedule = {'foo': 'bar'}
         s.schedule = {'foo': 'bar'}
         self.assertEqual(s.data, {'foo': 'bar'})
         self.assertEqual(s.data, {'foo': 'bar'})
 
 
     @patch('kombu.connection.Connection.ensure_connection')
     @patch('kombu.connection.Connection.ensure_connection')
     def test_ensure_connection_error_handler(self, ensure):
     def test_ensure_connection_error_handler(self, ensure):
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         self.assertTrue(s._ensure_connected())
         self.assertTrue(s._ensure_connected())
         self.assertTrue(ensure.called)
         self.assertTrue(ensure.called)
         callback = ensure.call_args[0][0]
         callback = ensure.call_args[0][0]
@@ -218,29 +216,32 @@ class test_Scheduler(Case):
         callback(KeyError(), 5)
         callback(KeyError(), 5)
 
 
     def test_install_default_entries(self):
     def test_install_default_entries(self):
-        with patch_settings(CELERY_TASK_RESULT_EXPIRES=None,
+        with patch_settings(self.app,
+                            CELERY_TASK_RESULT_EXPIRES=None,
                             CELERYBEAT_SCHEDULE={}):
                             CELERYBEAT_SCHEDULE={}):
-            s = mScheduler()
+            s = mScheduler(app=self.app)
             s.install_default_entries({})
             s.install_default_entries({})
             self.assertNotIn('celery.backend_cleanup', s.data)
             self.assertNotIn('celery.backend_cleanup', s.data)
-        current_app.backend.supports_autoexpire = False
-        with patch_settings(CELERY_TASK_RESULT_EXPIRES=30,
+        self.app.backend.supports_autoexpire = False
+        with patch_settings(self.app,
+                            CELERY_TASK_RESULT_EXPIRES=30,
                             CELERYBEAT_SCHEDULE={}):
                             CELERYBEAT_SCHEDULE={}):
-            s = mScheduler()
+            s = mScheduler(app=self.app)
             s.install_default_entries({})
             s.install_default_entries({})
             self.assertIn('celery.backend_cleanup', s.data)
             self.assertIn('celery.backend_cleanup', s.data)
-        current_app.backend.supports_autoexpire = True
+        self.app.backend.supports_autoexpire = True
         try:
         try:
-            with patch_settings(CELERY_TASK_RESULT_EXPIRES=31,
+            with patch_settings(self.app,
+                                CELERY_TASK_RESULT_EXPIRES=31,
                                 CELERYBEAT_SCHEDULE={}):
                                 CELERYBEAT_SCHEDULE={}):
-                s = mScheduler()
+                s = mScheduler(app=self.app)
                 s.install_default_entries({})
                 s.install_default_entries({})
                 self.assertNotIn('celery.backend_cleanup', s.data)
                 self.assertNotIn('celery.backend_cleanup', s.data)
         finally:
         finally:
-            current_app.backend.supports_autoexpire = False
+            self.app.backend.supports_autoexpire = False
 
 
     def test_due_tick(self):
     def test_due_tick(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_due_tick',
         scheduler.add(name='test_due_tick',
                       schedule=always_due,
                       schedule=always_due,
                       args=(1, 2),
                       args=(1, 2),
@@ -249,33 +250,33 @@ class test_Scheduler(Case):
 
 
     @patch('celery.beat.error')
     @patch('celery.beat.error')
     def test_due_tick_SchedulingError(self, error):
     def test_due_tick_SchedulingError(self, error):
-        scheduler = mSchedulerSchedulingError()
+        scheduler = mSchedulerSchedulingError(app=self.app)
         scheduler.add(name='test_due_tick_SchedulingError',
         scheduler.add(name='test_due_tick_SchedulingError',
                       schedule=always_due)
                       schedule=always_due)
         self.assertEqual(scheduler.tick(), 1)
         self.assertEqual(scheduler.tick(), 1)
         self.assertTrue(error.called)
         self.assertTrue(error.called)
 
 
     def test_due_tick_RuntimeError(self):
     def test_due_tick_RuntimeError(self):
-        scheduler = mSchedulerRuntimeError()
+        scheduler = mSchedulerRuntimeError(app=self.app)
         scheduler.add(name='test_due_tick_RuntimeError',
         scheduler.add(name='test_due_tick_RuntimeError',
                       schedule=always_due)
                       schedule=always_due)
         self.assertEqual(scheduler.tick(), scheduler.max_interval)
         self.assertEqual(scheduler.tick(), scheduler.max_interval)
 
 
     def test_pending_tick(self):
     def test_pending_tick(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_pending_tick',
         scheduler.add(name='test_pending_tick',
                       schedule=always_pending)
                       schedule=always_pending)
         self.assertEqual(scheduler.tick(), 1)
         self.assertEqual(scheduler.tick(), 1)
 
 
     def test_honors_max_interval(self):
     def test_honors_max_interval(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         maxi = scheduler.max_interval
         maxi = scheduler.max_interval
         scheduler.add(name='test_honors_max_interval',
         scheduler.add(name='test_honors_max_interval',
                       schedule=mocked_schedule(False, maxi * 4))
                       schedule=mocked_schedule(False, maxi * 4))
         self.assertEqual(scheduler.tick(), maxi)
         self.assertEqual(scheduler.tick(), maxi)
 
 
     def test_ticks(self):
     def test_ticks(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         nums = [600, 300, 650, 120, 250, 36]
         nums = [600, 300, 650, 120, 250, 36]
         s = dict(('test_ticks%s' % i,
         s = dict(('test_ticks%s' % i,
                  {'schedule': mocked_schedule(False, j)})
                  {'schedule': mocked_schedule(False, j)})
@@ -284,20 +285,20 @@ class test_Scheduler(Case):
         self.assertEqual(scheduler.tick(), min(nums))
         self.assertEqual(scheduler.tick(), min(nums))
 
 
     def test_schedule_no_remain(self):
     def test_schedule_no_remain(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_schedule_no_remain',
         scheduler.add(name='test_schedule_no_remain',
                       schedule=mocked_schedule(False, None))
                       schedule=mocked_schedule(False, None))
         self.assertEqual(scheduler.tick(), scheduler.max_interval)
         self.assertEqual(scheduler.tick(), scheduler.max_interval)
 
 
     def test_interface(self):
     def test_interface(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.sync()
         scheduler.sync()
         scheduler.setup_schedule()
         scheduler.setup_schedule()
         scheduler.close()
         scheduler.close()
 
 
     def test_merge_inplace(self):
     def test_merge_inplace(self):
-        a = mScheduler()
-        b = mScheduler()
+        a = mScheduler(app=self.app)
+        b = mScheduler(app=self.app)
         a.update_from_dict({'foo': {'schedule': mocked_schedule(True, 10)},
         a.update_from_dict({'foo': {'schedule': mocked_schedule(True, 10)},
                             'bar': {'schedule': mocked_schedule(True, 20)}})
                             'bar': {'schedule': mocked_schedule(True, 20)}})
         b.update_from_dict({'bar': {'schedule': mocked_schedule(True, 40)},
         b.update_from_dict({'bar': {'schedule': mocked_schedule(True, 40)},
@@ -330,11 +331,12 @@ def create_persistent_scheduler(shelv=None):
     return MockPersistentScheduler, shelv
     return MockPersistentScheduler, shelv
 
 
 
 
-class test_PersistentScheduler(Case):
+class test_PersistentScheduler(AppCase):
 
 
     @patch('os.remove')
     @patch('os.remove')
     def test_remove_db(self, remove):
     def test_remove_db(self, remove):
-        s = create_persistent_scheduler()[0](schedule_filename='schedule')
+        s = create_persistent_scheduler()[0](app=self.app,
+                                             schedule_filename='schedule')
         s._remove_db()
         s._remove_db()
         remove.assert_has_calls(
         remove.assert_has_calls(
             [call('schedule' + suffix) for suffix in s.known_suffixes]
             [call('schedule' + suffix) for suffix in s.known_suffixes]
@@ -348,7 +350,8 @@ class test_PersistentScheduler(Case):
             s._remove_db()
             s._remove_db()
 
 
     def test_setup_schedule(self):
     def test_setup_schedule(self):
-        s = create_persistent_scheduler()[0](schedule_filename='schedule')
+        s = create_persistent_scheduler()[0](app=self.app,
+                                             schedule_filename='schedule')
         opens = s.persistence.open = Mock()
         opens = s.persistence.open = Mock()
         s._remove_db = Mock()
         s._remove_db = Mock()
 
 
@@ -383,14 +386,14 @@ class test_PersistentScheduler(Case):
         self.assertDictEqual(s._store['entries'], s.schedule)
         self.assertDictEqual(s._store['entries'], s.schedule)
 
 
 
 
-class test_Service(Case):
+class test_Service(AppCase):
 
 
     def get_service(self):
     def get_service(self):
         Scheduler, mock_shelve = create_persistent_scheduler()
         Scheduler, mock_shelve = create_persistent_scheduler()
-        return beat.Service(scheduler_cls=Scheduler), mock_shelve
+        return beat.Service(app=self.app, scheduler_cls=Scheduler), mock_shelve
 
 
     def test_pickleable(self):
     def test_pickleable(self):
-        s = beat.Service(scheduler_cls=Mock)
+        s = beat.Service(app=self.app, scheduler_cls=Mock)
         self.assertTrue(loads(dumps(s)))
         self.assertTrue(loads(dumps(s)))
 
 
     def test_start(self):
     def test_start(self):
@@ -442,7 +445,7 @@ class test_Service(Case):
         self.assertTrue(s._is_shutdown.isSet())
         self.assertTrue(s._is_shutdown.isSet())
 
 
 
 
-class test_EmbeddedService(Case):
+class test_EmbeddedService(AppCase):
 
 
     def test_start_stop_process(self):
     def test_start_stop_process(self):
         try:
         try:

+ 46 - 43
celery/tests/app/test_builtins.py

@@ -2,93 +2,97 @@ from __future__ import absolute_import
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import current_app as app, group, task, chord
+from celery import group, shared_task, chord
 from celery.app import builtins
 from celery.app import builtins
 from celery.canvas import Signature
 from celery.canvas import Signature
 from celery.five import range
 from celery.five import range
 from celery._state import _task_stack
 from celery._state import _task_stack
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 
 
 
 
-@task()
+@shared_task()
 def add(x, y):
 def add(x, y):
     return x + y
     return x + y
 
 
 
 
-@task()
+@shared_task()
 def xsum(x):
 def xsum(x):
     return sum(x)
     return sum(x)
 
 
 
 
-class test_backend_cleanup(Case):
+class test_backend_cleanup(AppCase):
 
 
     def test_run(self):
     def test_run(self):
-        prev = app.backend
-        app.backend.cleanup = Mock()
-        app.backend.cleanup.__name__ = 'cleanup'
+        prev = self.app.backend
+        self.app.backend.cleanup = Mock()
+        self.app.backend.cleanup.__name__ = 'cleanup'
         try:
         try:
-            cleanup_task = builtins.add_backend_cleanup_task(app)
+            cleanup_task = builtins.add_backend_cleanup_task(self.app)
             cleanup_task()
             cleanup_task()
-            self.assertTrue(app.backend.cleanup.called)
+            self.assertTrue(self.app.backend.cleanup.called)
         finally:
         finally:
-            app.backend = prev
+            self.app.backend = prev
 
 
 
 
-class test_map(Case):
+class test_map(AppCase):
 
 
     def test_run(self):
     def test_run(self):
 
 
-        @app.task()
+        @self.app.task()
         def map_mul(x):
         def map_mul(x):
             return x[0] * x[1]
             return x[0] * x[1]
 
 
-        res = app.tasks['celery.map'](map_mul, [(2, 2), (4, 4), (8, 8)])
+        res = self.app.tasks['celery.map'](
+            map_mul, [(2, 2), (4, 4), (8, 8)],
+        )
         self.assertEqual(res, [4, 16, 64])
         self.assertEqual(res, [4, 16, 64])
 
 
 
 
-class test_starmap(Case):
+class test_starmap(AppCase):
 
 
     def test_run(self):
     def test_run(self):
 
 
-        @app.task()
+        @self.app.task()
         def smap_mul(x, y):
         def smap_mul(x, y):
             return x * y
             return x * y
 
 
-        res = app.tasks['celery.starmap'](smap_mul, [(2, 2), (4, 4), (8, 8)])
+        res = self.app.tasks['celery.starmap'](
+            smap_mul, [(2, 2), (4, 4), (8, 8)],
+        )
         self.assertEqual(res, [4, 16, 64])
         self.assertEqual(res, [4, 16, 64])
 
 
 
 
-class test_chunks(Case):
+class test_chunks(AppCase):
 
 
     @patch('celery.canvas.chunks.apply_chunks')
     @patch('celery.canvas.chunks.apply_chunks')
     def test_run(self, apply_chunks):
     def test_run(self, apply_chunks):
 
 
-        @app.task()
+        @self.app.task()
         def chunks_mul(l):
         def chunks_mul(l):
             return l
             return l
 
 
-        app.tasks['celery.chunks'](
+        self.app.tasks['celery.chunks'](
             chunks_mul, [(2, 2), (4, 4), (8, 8)], 1,
             chunks_mul, [(2, 2), (4, 4), (8, 8)], 1,
         )
         )
         self.assertTrue(apply_chunks.called)
         self.assertTrue(apply_chunks.called)
 
 
 
 
-class test_group(Case):
+class test_group(AppCase):
 
 
-    def setUp(self):
-        self.prev = app.tasks.get('celery.group')
-        self.task = builtins.add_group_task(app)()
+    def setup(self):
+        self.prev = self.app.tasks.get('celery.group')
+        self.task = builtins.add_group_task(self.app)()
 
 
-    def tearDown(self):
-        app.tasks['celery.group'] = self.prev
+    def teardown(self):
+        self.app.tasks['celery.group'] = self.prev
 
 
     def test_apply_async_eager(self):
     def test_apply_async_eager(self):
         self.task.apply = Mock()
         self.task.apply = Mock()
-        app.conf.CELERY_ALWAYS_EAGER = True
+        self.app.conf.CELERY_ALWAYS_EAGER = True
         try:
         try:
             self.task.apply_async()
             self.task.apply_async()
         finally:
         finally:
-            app.conf.CELERY_ALWAYS_EAGER = False
+            self.app.conf.CELERY_ALWAYS_EAGER = False
         self.assertTrue(self.task.apply.called)
         self.assertTrue(self.task.apply.called)
 
 
     def test_apply(self):
     def test_apply(self):
@@ -125,14 +129,14 @@ class test_group(Case):
             _task_stack.pop()
             _task_stack.pop()
 
 
 
 
-class test_chain(Case):
+class test_chain(AppCase):
 
 
-    def setUp(self):
-        self.prev = app.tasks.get('celery.chain')
-        self.task = builtins.add_chain_task(app)()
+    def setup(self):
+        self.prev = self.app.tasks.get('celery.chain')
+        self.task = builtins.add_chain_task(self.app)()
 
 
-    def tearDown(self):
-        app.tasks['celery.chain'] = self.prev
+    def teardown(self):
+        self.app.tasks['celery.chain'] = self.prev
 
 
     def test_apply_async(self):
     def test_apply_async(self):
         c = add.s(2, 2) | add.s(4) | add.s(8)
         c = add.s(2, 2) | add.s(4) | add.s(8)
@@ -185,14 +189,14 @@ class test_chain(Case):
             self.assertListEqual(task.options['link_error'], [s('error')])
             self.assertListEqual(task.options['link_error'], [s('error')])
 
 
 
 
-class test_chord(Case):
+class test_chord(AppCase):
 
 
-    def setUp(self):
-        self.prev = app.tasks.get('celery.chord')
-        self.task = builtins.add_chord_task(app)()
+    def setup(self):
+        self.prev = self.app.tasks.get('celery.chord')
+        self.task = builtins.add_chord_task(self.app)()
 
 
-    def tearDown(self):
-        app.tasks['celery.chord'] = self.prev
+    def teardown(self):
+        self.app.tasks['celery.chord'] = self.prev
 
 
     def test_apply_async(self):
     def test_apply_async(self):
         x = chord([add.s(i, i) for i in range(10)], body=xsum.s())
         x = chord([add.s(i, i) for i in range(10)], body=xsum.s())
@@ -213,11 +217,10 @@ class test_chord(Case):
         self.assertEqual(body.options['chord'], 'some_chord_id')
         self.assertEqual(body.options['chord'], 'some_chord_id')
 
 
     def test_apply_eager(self):
     def test_apply_eager(self):
-        app.conf.CELERY_ALWAYS_EAGER = True
+        self.app.conf.CELERY_ALWAYS_EAGER = True
         try:
         try:
             x = chord([add.s(i, i) for i in range(10)], body=xsum.s())
             x = chord([add.s(i, i) for i in range(10)], body=xsum.s())
             r = x.apply_async()
             r = x.apply_async()
             self.assertEqual(r.get(), 90)
             self.assertEqual(r.get(), 90)
-
         finally:
         finally:
-            app.conf.CELERY_ALWAYS_EAGER = False
+            self.app.conf.CELERY_ALWAYS_EAGER = False

+ 5 - 8
celery/tests/app/test_loaders.py

@@ -7,7 +7,6 @@ import warnings
 from mock import Mock, patch
 from mock import Mock, patch
 
 
 from celery import loaders
 from celery import loaders
-from celery.app import app_or_default
 from celery.exceptions import (
 from celery.exceptions import (
     NotConfigured,
     NotConfigured,
     CPendingDeprecationWarning,
     CPendingDeprecationWarning,
@@ -47,7 +46,7 @@ class test_loaders(AppCase):
             self.assertIs(loaders.load_settings(), self.app.conf)
             self.assertIs(loaders.load_settings(), self.app.conf)
 
 
 
 
-class test_LoaderBase(Case):
+class test_LoaderBase(AppCase):
     message_options = {'subject': 'Subject',
     message_options = {'subject': 'Subject',
                        'body': 'Body',
                        'body': 'Body',
                        'sender': 'x@x.com',
                        'sender': 'x@x.com',
@@ -58,9 +57,8 @@ class test_LoaderBase(Case):
                       'password': 'qwerty',
                       'password': 'qwerty',
                       'timeout': 3}
                       'timeout': 3}
 
 
-    def setUp(self):
-        self.loader = DummyLoader()
-        self.app = app_or_default()
+    def setup(self):
+        self.loader = DummyLoader(app=self.app)
 
 
     def test_handlers_pass(self):
     def test_handlers_pass(self):
         self.loader.on_task_init('foo.task', 'feedface-cafebabe')
         self.loader.on_task_init('foo.task', 'feedface-cafebabe')
@@ -222,10 +220,9 @@ class test_DefaultLoader(Case):
         self.assertTrue(context_executed[0])
         self.assertTrue(context_executed[0])
 
 
 
 
-class test_AppLoader(Case):
+class test_AppLoader(AppCase):
 
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.loader = AppLoader(app=self.app)
         self.loader = AppLoader(app=self.app)
 
 
     def test_on_worker_init(self):
     def test_on_worker_init(self):

+ 14 - 13
celery/tests/app/test_log.py

@@ -7,7 +7,6 @@ from tempfile import mktemp
 from mock import patch, Mock
 from mock import patch, Mock
 from nose import SkipTest
 from nose import SkipTest
 
 
-from celery import current_app
 from celery import signals
 from celery import signals
 from celery.app.log import Logging, TaskFormatter
 from celery.app.log import Logging, TaskFormatter
 from celery.utils.log import LoggingProxy
 from celery.utils.log import LoggingProxy
@@ -22,7 +21,6 @@ from celery.tests.utils import (
     AppCase, Case, override_stdouts, wrap_logger, get_handlers,
     AppCase, Case, override_stdouts, wrap_logger, get_handlers,
 )
 )
 
 
-log = current_app.log
 
 
 
 
 class test_TaskFormatter(Case):
 class test_TaskFormatter(Case):
@@ -110,7 +108,7 @@ class test_ColorFormatter(Case):
 class test_default_logger(AppCase):
 class test_default_logger(AppCase):
 
 
     def setup(self):
     def setup(self):
-        self.setup_logger = log.setup_logger
+        self.setup_logger = self.app.log.setup_logger
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         signals.setup_logging.receivers[:] = []
         signals.setup_logging.receivers[:] = []
         Logging._setup = False
         Logging._setup = False
@@ -124,12 +122,12 @@ class test_default_logger(AppCase):
         self.assertIs(logger.parent, logging.root)
         self.assertIs(logger.parent, logging.root)
 
 
     def test_setup_logging_subsystem_misc(self):
     def test_setup_logging_subsystem_misc(self):
-        log.setup_logging_subsystem(loglevel=None)
+        self.app.log.setup_logging_subsystem(loglevel=None)
 
 
     def test_setup_logging_subsystem_misc2(self):
     def test_setup_logging_subsystem_misc2(self):
         self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
         self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
         try:
         try:
-            log.setup_logging_subsystem()
+            self.app.log.setup_logging_subsystem()
         finally:
         finally:
             self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
             self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
 
 
@@ -142,14 +140,14 @@ class test_default_logger(AppCase):
         logger.handlers[:] = []
         logger.handlers[:] = []
 
 
     def test_setup_logging_subsystem_colorize(self):
     def test_setup_logging_subsystem_colorize(self):
-        log.setup_logging_subsystem(colorize=None)
-        log.setup_logging_subsystem(colorize=True)
+        self.app.log.setup_logging_subsystem(colorize=None)
+        self.app.log.setup_logging_subsystem(colorize=True)
 
 
     def test_setup_logging_subsystem_no_mputil(self):
     def test_setup_logging_subsystem_no_mputil(self):
         from celery.utils import log as logtools
         from celery.utils import log as logtools
         mputil, logtools.mputil = logtools.mputil, None
         mputil, logtools.mputil = logtools.mputil, None
         try:
         try:
-            log.setup_logging_subsystem()
+            self.app.log.setup_logging_subsystem()
         finally:
         finally:
             logtools.mputil = mputil
             logtools.mputil = mputil
 
 
@@ -203,11 +201,14 @@ class test_default_logger(AppCase):
                                    root=False)
                                    root=False)
         try:
         try:
             with wrap_logger(logger) as sio:
             with wrap_logger(logger) as sio:
-                log.redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, loglevel=logging.ERROR,
+                )
                 logger.error('foo')
                 logger.error('foo')
                 self.assertIn('foo', sio.getvalue())
                 self.assertIn('foo', sio.getvalue())
-                log.redirect_stdouts_to_logger(logger, stdout=False,
-                                               stderr=False)
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, stdout=False, stderr=False,
+                )
         finally:
         finally:
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
 
@@ -250,7 +251,7 @@ class test_task_logger(test_default_logger):
         logging.root.manager.loggerDict.pop(logger.name, None)
         logging.root.manager.loggerDict.pop(logger.name, None)
         self.uid = uuid()
         self.uid = uuid()
 
 
-        @current_app.task
+        @self.app.task
         def test_task():
         def test_task():
             pass
             pass
         self.get_logger().handlers = []
         self.get_logger().handlers = []
@@ -263,7 +264,7 @@ class test_task_logger(test_default_logger):
         _task_stack.pop()
         _task_stack.pop()
 
 
     def setup_logger(self, *args, **kwargs):
     def setup_logger(self, *args, **kwargs):
-        return log.setup_task_loggers(*args, **kwargs)
+        return self.app.log.setup_task_loggers(*args, **kwargs)
 
 
     def get_logger(self, *args, **kwargs):
     def get_logger(self, *args, **kwargs):
         return get_task_logger("test_task_logger")
         return get_task_logger("test_task_logger")

+ 101 - 100
celery/tests/app/test_routes.py

@@ -1,19 +1,18 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from functools import wraps
+from contextlib import contextmanager
 
 
 from kombu import Exchange
 from kombu import Exchange
 from kombu.utils.functional import maybe_promise
 from kombu.utils.functional import maybe_promise
 
 
-from celery import current_app
 from celery.app import routes
 from celery.app import routes
 from celery.exceptions import QueueNotFound
 from celery.exceptions import QueueNotFound
 from celery.task import task
 from celery.task import task
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 
 
 
 
-def Router(*args, **kwargs):
-    return routes.Router(*args, app=current_app, **kwargs)
+def Router(app, *args, **kwargs):
+    return routes.Router(*args, app=app, **kwargs)
 
 
 
 
 @task()
 @task()
@@ -21,72 +20,70 @@ def mytask():
     pass
     pass
 
 
 
 
-def E(queues):
+def E(app, queues):
     def expand(answer):
     def expand(answer):
-        return Router([], queues).expand_destination(answer)
+        return Router(app, [], queues).expand_destination(answer)
     return expand
     return expand
 
 
 
 
-def with_queues(**queues):
-
-    def patch_fun(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            app = current_app
-            prev_queues = app.conf.CELERY_QUEUES
-            prev_Queues = app.amqp.queues
-            app.conf.CELERY_QUEUES = queues
-            app.amqp.queues = app.amqp.Queues(queues)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                app.conf.CELERY_QUEUES = prev_queues
-                app.amqp.queues = prev_Queues
-        return __inner
-    return patch_fun
-
-
-a_queue = {'exchange': 'fooexchange',
-           'exchange_type': 'fanout',
-           'routing_key': 'xuzzy'}
-b_queue = {'exchange': 'barexchange',
-           'exchange_type': 'topic',
-           'routing_key': 'b.b.#'}
-d_queue = {'exchange': current_app.conf.CELERY_DEFAULT_EXCHANGE,
-           'exchange_type': current_app.conf.CELERY_DEFAULT_EXCHANGE_TYPE,
-           'routing_key': current_app.conf.CELERY_DEFAULT_ROUTING_KEY}
-
-
-class RouteCase(Case):
-    pass
+@contextmanager
+def _queues(app, **queues):
+    prev_queues = app.conf.CELERY_QUEUES
+    prev_Queues = app.amqp.queues
+    app.conf.CELERY_QUEUES = queues
+    app.amqp.queues = app.amqp.Queues(queues)
+    try:
+        yield
+    finally:
+        app.conf.CELERY_QUEUES = prev_queues
+        app.amqp.queues = prev_Queues
+
+
+class RouteCase(AppCase):
+
+    def setup(self):
+        self.a_queue = {
+            'exchange': 'fooexchange',
+            'exchange_type': 'fanout',
+            'routing_key': 'xuzzy',
+        }
+        self.b_queue = {
+            'exchange': 'barexchange',
+            'exchange_type': 'topic',
+            'routing_key': 'b.b.#',
+        }
+        self.d_queue = {
+            'exchange': self.app.conf.CELERY_DEFAULT_EXCHANGE,
+            'exchange_type': self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE,
+            'routing_key': self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
+        }
 
 
 
 
 class test_MapRoute(RouteCase):
 class test_MapRoute(RouteCase):
 
 
-    @with_queues(foo=a_queue, bar=b_queue)
     def test_route_for_task_expanded_route(self):
     def test_route_for_task_expanded_route(self):
-        expand = E(current_app.amqp.queues)
-        route = routes.MapRoute({mytask.name: {'queue': 'foo'}})
-        self.assertEqual(
-            expand(route.route_for_task(mytask.name))['queue'].name,
-            'foo',
-        )
-        self.assertIsNone(route.route_for_task('celery.awesome'))
-
-    @with_queues(foo=a_queue, bar=b_queue)
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            expand = E(self.app, self.app.amqp.queues)
+            route = routes.MapRoute({mytask.name: {'queue': 'foo'}})
+            self.assertEqual(
+                expand(route.route_for_task(mytask.name))['queue'].name,
+                'foo',
+            )
+            self.assertIsNone(route.route_for_task('celery.awesome'))
+
     def test_route_for_task(self):
     def test_route_for_task(self):
-        expand = E(current_app.amqp.queues)
-        route = routes.MapRoute({mytask.name: b_queue})
-        self.assertDictContainsSubset(
-            b_queue,
-            expand(route.route_for_task(mytask.name)),
-        )
-        self.assertIsNone(route.route_for_task('celery.awesome'))
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            expand = E(self.app, self.app.amqp.queues)
+            route = routes.MapRoute({mytask.name: self.b_queue})
+            self.assertDictContainsSubset(
+                self.b_queue,
+                expand(route.route_for_task(mytask.name)),
+            )
+            self.assertIsNone(route.route_for_task('celery.awesome'))
 
 
     def test_expand_route_not_found(self):
     def test_expand_route_not_found(self):
-        expand = E(current_app.amqp.Queues(
-            current_app.conf.CELERY_QUEUES, False))
+        expand = E(self.app, self.app.amqp.Queues(
+                   self.app.conf.CELERY_QUEUES, False))
         route = routes.MapRoute({'a': {'queue': 'x'}})
         route = routes.MapRoute({'a': {'queue': 'x'}})
         with self.assertRaises(QueueNotFound):
         with self.assertRaises(QueueNotFound):
             expand(route.route_for_task('a'))
             expand(route.route_for_task('a'))
@@ -95,55 +92,59 @@ class test_MapRoute(RouteCase):
 class test_lookup_route(RouteCase):
 class test_lookup_route(RouteCase):
 
 
     def test_init_queues(self):
     def test_init_queues(self):
-        router = Router(queues=None)
+        router = Router(self.app, queues=None)
         self.assertDictEqual(router.queues, {})
         self.assertDictEqual(router.queues, {})
 
 
-    @with_queues(foo=a_queue, bar=b_queue)
     def test_lookup_takes_first(self):
     def test_lookup_takes_first(self):
-        R = routes.prepare(({mytask.name: {'queue': 'bar'}},
-                            {mytask.name: {'queue': 'foo'}}))
-        router = Router(R, current_app.amqp.queues)
-        self.assertEqual(router.route({}, mytask.name,
-                         args=[1, 2], kwargs={})['queue'].name, 'bar')
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            R = routes.prepare(({mytask.name: {'queue': 'bar'}},
+                                {mytask.name: {'queue': 'foo'}}))
+            router = Router(self.app, R, self.app.amqp.queues)
+            self.assertEqual(router.route({}, mytask.name,
+                             args=[1, 2], kwargs={})['queue'].name, 'bar')
 
 
-    @with_queues()
     def test_expands_queue_in_options(self):
     def test_expands_queue_in_options(self):
-        R = routes.prepare(())
-        router = Router(R, current_app.amqp.queues, create_missing=True)
-        # apply_async forwards all arguments, even exchange=None etc,
-        # so need to make sure it's merged correctly.
-        route = router.route({'queue': 'testq',
-                              'exchange': None,
-                              'routing_key': None,
-                              'immediate': False},
-                             mytask.name,
-                             args=[1, 2], kwargs={})
-        self.assertEqual(route['queue'].name, 'testq')
-        self.assertEqual(route['queue'].exchange, Exchange('testq'))
-        self.assertEqual(route['queue'].routing_key, 'testq')
-        self.assertEqual(route['immediate'], False)
-
-    @with_queues(foo=a_queue, bar=b_queue)
+        with _queues(self.app):
+            R = routes.prepare(())
+            router = Router(self.app, R, self.app.amqp.queues, create_missing=True)
+            # apply_async forwards all arguments, even exchange=None etc,
+            # so need to make sure it's merged correctly.
+            route = router.route(
+                {'queue': 'testq',
+                 'exchange': None,
+                 'routing_key': None,
+                 'immediate': False},
+                mytask.name,
+                args=[1, 2], kwargs={},
+            )
+            self.assertEqual(route['queue'].name, 'testq')
+            self.assertEqual(route['queue'].exchange, Exchange('testq'))
+            self.assertEqual(route['queue'].routing_key, 'testq')
+            self.assertEqual(route['immediate'], False)
+
     def test_expand_destination_string(self):
     def test_expand_destination_string(self):
-        x = Router({}, current_app.amqp.queues)
-        dest = x.expand_destination('foo')
-        self.assertEqual(dest['queue'].name, 'foo')
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            x = Router(self.app, {}, self.app.amqp.queues)
+            dest = x.expand_destination('foo')
+            self.assertEqual(dest['queue'].name, 'foo')
 
 
-    @with_queues(foo=a_queue, bar=b_queue, **{
-        current_app.conf.CELERY_DEFAULT_QUEUE: d_queue})
     def test_lookup_paths_traversed(self):
     def test_lookup_paths_traversed(self):
-        R = routes.prepare(({'celery.xaza': {'queue': 'bar'}},
-                            {mytask.name: {'queue': 'foo'}}))
-        router = Router(R, current_app.amqp.queues)
-        self.assertEqual(router.route({}, mytask.name,
-                         args=[1, 2], kwargs={})['queue'].name, 'foo')
-        self.assertEqual(
-            router.route({}, 'celery.poza')['queue'].name,
-            current_app.conf.CELERY_DEFAULT_QUEUE,
-        )
-
-
-class test_prepare(Case):
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue, **{
+                self.app.conf.CELERY_DEFAULT_QUEUE: self.d_queue}):
+            R = routes.prepare((
+                {'celery.xaza': {'queue': 'bar'}},
+                {mytask.name: {'queue': 'foo'}}
+            ))
+            router = Router(self.app, R, self.app.amqp.queues)
+            self.assertEqual(router.route({}, mytask.name,
+                             args=[1, 2], kwargs={})['queue'].name, 'foo')
+            self.assertEqual(
+                router.route({}, 'celery.poza')['queue'].name,
+                self.app.conf.CELERY_DEFAULT_QUEUE,
+            )
+
+
+class test_prepare(AppCase):
 
 
     def test_prepare(self):
     def test_prepare(self):
         from celery.datastructures import LRUCache
         from celery.datastructures import LRUCache

+ 1 - 2
celery/tests/utilities/test_platforms.py

@@ -7,7 +7,6 @@ import signal
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
-from celery import current_app
 from celery import platforms
 from celery import platforms
 from celery.five import open_fqdn
 from celery.five import open_fqdn
 from celery.platforms import (
 from celery.platforms import (
@@ -97,7 +96,7 @@ class test_Signals(Case):
         signals['INT'] = lambda *a: a
         signals['INT'] = lambda *a: a
 
 
 
 
-if not current_app.IS_WINDOWS:
+if not platforms.IS_WINDOWS:
 
 
     class test_get_fdmax(Case):
     class test_get_fdmax(Case):
 
 

+ 95 - 33
celery/tests/worker/test_control.py

@@ -9,7 +9,6 @@ from datetime import datetime, timedelta
 from kombu import pidbox
 from kombu import pidbox
 from mock import Mock, patch, call
 from mock import Mock, patch, call
 
 
-from celery import current_app
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
 from celery.task import task
 from celery.task import task
 from celery.utils import uuid
 from celery.utils import uuid
@@ -22,7 +21,8 @@ from celery.five import Queue as FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.job import TaskRequest
 from celery.worker.state import revoked
 from celery.worker.state import revoked
 from celery.worker.control import Panel
 from celery.worker.control import Panel
-from celery.tests.utils import Case
+from celery.worker.pidbox import Pidbox, gPidbox
+from celery.tests.utils import AppCase, Case
 
 
 hostname = socket.gethostname()
 hostname = socket.gethostname()
 
 
@@ -41,11 +41,11 @@ class WorkController(object):
 
 
 class Consumer(consumer.Consumer):
 class Consumer(consumer.Consumer):
 
 
-    def __init__(self):
+    def __init__(self, app):
+        self.app = app
         self.buffer = FastQueue()
         self.buffer = FastQueue()
         self.handle_task = self.buffer.put
         self.handle_task = self.buffer.put
         self.timer = Timer()
         self.timer = Timer()
-        self.app = current_app
         self.event_dispatcher = Mock()
         self.event_dispatcher = Mock()
         self.controller = WorkController()
         self.controller = WorkController()
         self.task_consumer = Mock()
         self.task_consumer = Mock()
@@ -55,11 +55,73 @@ class Consumer(consumer.Consumer):
         self.task_buckets = defaultdict(lambda: None)
         self.task_buckets = defaultdict(lambda: None)
 
 
 
 
-class test_ControlPanel(Case):
+class test_Pidbox(AppCase):
 
 
-    def setUp(self):
-        self.app = current_app
-        self.panel = self.create_panel(consumer=Consumer())
+    def test_shutdown(self):
+        with patch('celery.worker.pidbox.ignore_errors') as eig:
+            parent = Mock()
+            pidbox = Pidbox(parent)
+            pidbox._close_channel = Mock()
+            self.assertIs(pidbox.c, parent)
+            pconsumer = pidbox.consumer = Mock()
+            cancel = pconsumer.cancel
+            pidbox.shutdown(parent)
+            eig.assert_called_with(parent, cancel)
+            pidbox._close_channel.assert_called_with(parent)
+
+
+class test_Pidbox_green(AppCase):
+
+    def test_stop(self):
+        parent = Mock()
+        g = gPidbox(parent)
+        stopped = g._node_stopped = Mock()
+        shutdown = g._node_shutdown = Mock()
+        close_chan = g._close_channel = Mock()
+
+        g.stop(parent)
+        shutdown.set.assert_called_with()
+        stopped.wait.assert_called_with()
+        close_chan.assert_called_with(parent)
+        self.assertIsNone(g._node_stopped)
+        self.assertIsNone(g._node_shutdown)
+
+        close_chan.reset()
+        g.stop(parent)
+        close_chan.assert_called_with(parent)
+
+    def test_resets(self):
+        parent = Mock()
+        g = gPidbox(parent)
+        g._resets = 100
+        g.reset()
+        self.assertEqual(g._resets, 101)
+
+    def test_loop(self):
+        parent = Mock()
+        conn = parent.connect.return_value = self.app.connection()
+        drain = conn.drain_events = Mock()
+        g = gPidbox(parent)
+        parent.connection = Mock()
+        do_reset = g._do_reset = Mock()
+
+        call_count = [0]
+
+        def se(*args, **kwargs):
+            if call_count[0] > 2:
+                g._node_shutdown.set()
+            g.reset()
+            call_count[0] += 1
+        drain.side_effect = se
+        g.loop(parent)
+
+        self.assertEqual(do_reset.call_count, 4)
+
+
+class test_ControlPanel(AppCase):
+
+    def setup(self):
+        self.panel = self.create_panel(consumer=Consumer(self.app))
 
 
     def create_state(self, **kwargs):
     def create_state(self, **kwargs):
         kwargs.setdefault('app', self.app)
         kwargs.setdefault('app', self.app)
@@ -71,7 +133,7 @@ class test_ControlPanel(Case):
                                              handlers=Panel.data)
                                              handlers=Panel.data)
 
 
     def test_enable_events(self):
     def test_enable_events(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         evd = consumer.event_dispatcher
         evd = consumer.event_dispatcher
         evd.groups = set()
         evd.groups = set()
@@ -81,7 +143,7 @@ class test_ControlPanel(Case):
         self.assertIn('already enabled', panel.handle('enable_events')['ok'])
         self.assertIn('already enabled', panel.handle('enable_events')['ok'])
 
 
     def test_disable_events(self):
     def test_disable_events(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         evd = consumer.event_dispatcher
         evd = consumer.event_dispatcher
         evd.enabled = True
         evd.enabled = True
@@ -91,7 +153,7 @@ class test_ControlPanel(Case):
         self.assertIn('already disabled', panel.handle('disable_events')['ok'])
         self.assertIn('already disabled', panel.handle('disable_events')['ok'])
 
 
     def test_heartbeat(self):
     def test_heartbeat(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         consumer.event_dispatcher.enabled = True
         consumer.event_dispatcher.enabled = True
         panel.handle('heartbeat')
         panel.handle('heartbeat')
@@ -122,7 +184,7 @@ class test_ControlPanel(Case):
     def test_active_queues(self):
     def test_active_queues(self):
         import kombu
         import kombu
 
 
-        x = kombu.Consumer(current_app.connection(),
+        x = kombu.Consumer(self.app.connection(),
                            [kombu.Queue('foo', kombu.Exchange('foo'), 'foo'),
                            [kombu.Queue('foo', kombu.Exchange('foo'), 'foo'),
                             kombu.Queue('bar', kombu.Exchange('bar'), 'bar')],
                             kombu.Queue('bar', kombu.Exchange('bar'), 'bar')],
                            auto_declare=False)
                            auto_declare=False)
@@ -170,7 +232,7 @@ class test_ControlPanel(Case):
             def shrink(self, n=1):
             def shrink(self, n=1):
                 self.size -= n
                 self.size -= n
 
 
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         consumer.pool = MockPool()
         consumer.pool = MockPool()
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
 
 
@@ -206,7 +268,7 @@ class test_ControlPanel(Case):
             def consuming_from(self, queue):
             def consuming_from(self, queue):
                 return queue in self.queues
                 return queue in self.queues
 
 
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         consumer.task_consumer = MockConsumer()
         consumer.task_consumer = MockConsumer()
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
 
 
@@ -229,7 +291,7 @@ class test_ControlPanel(Case):
             state.revoked.clear()
             state.revoked.clear()
 
 
     def test_dump_schedule(self):
     def test_dump_schedule(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         self.assertFalse(panel.handle('dump_schedule'))
         self.assertFalse(panel.handle('dump_schedule'))
         r = TaskRequest(mytask.name, 'CAFEBABE', (), {})
         r = TaskRequest(mytask.name, 'CAFEBABE', (), {})
@@ -240,7 +302,7 @@ class test_ControlPanel(Case):
 
 
     def test_dump_reserved(self):
     def test_dump_reserved(self):
         from celery.worker import state
         from celery.worker import state
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         state.reserved_requests.add(
         state.reserved_requests.add(
             TaskRequest(mytask.name, uuid(), args=(2, 2), kwargs={}),
             TaskRequest(mytask.name, uuid(), args=(2, 2), kwargs={}),
         )
         )
@@ -266,16 +328,16 @@ class test_ControlPanel(Case):
 
 
     def test_rate_limit(self):
     def test_rate_limit(self):
 
 
-        class Consumer(object):
+        class xConsumer(object):
             reset = False
             reset = False
 
 
             def reset_rate_limits(self):
             def reset_rate_limits(self):
                 self.reset = True
                 self.reset = True
 
 
-        consumer = Consumer()
-        panel = self.create_panel(app=current_app, consumer=consumer)
+        consumer = xConsumer()
+        panel = self.create_panel(app=self.app, consumer=consumer)
 
 
-        task = current_app.tasks[mytask.name]
+        task = self.app.tasks[mytask.name]
         old_rate_limit = task.rate_limit
         old_rate_limit = task.rate_limit
         try:
         try:
             panel.handle('rate_limit', arguments=dict(task_name=task.name,
             panel.handle('rate_limit', arguments=dict(task_name=task.name,
@@ -383,7 +445,7 @@ class test_ControlPanel(Case):
                 replies.append(data)
                 replies.append(data)
 
 
         panel = _Node(hostname=hostname,
         panel = _Node(hostname=hostname,
-                      state=self.create_state(consumer=Consumer()),
+                      state=self.create_state(consumer=Consumer(self.app)),
                       handlers=Panel.data,
                       handlers=Panel.data,
                       mailbox=self.app.control.mailbox)
                       mailbox=self.app.control.mailbox)
         r = panel.dispatch('ping', reply_to={'exchange': 'x',
         r = panel.dispatch('ping', reply_to={'exchange': 'x',
@@ -392,8 +454,8 @@ class test_ControlPanel(Case):
         self.assertDictEqual(replies[0], {panel.hostname: {'ok': 'pong'}})
         self.assertDictEqual(replies[0], {panel.hostname: {'ok': 'pong'}})
 
 
     def test_pool_restart(self):
     def test_pool_restart(self):
-        consumer = Consumer()
-        consumer.controller = _WC(app=current_app)
+        consumer = Consumer(self.app)
+        consumer.controller = _WC(app=self.app)
         consumer.controller.pool.restart = Mock()
         consumer.controller.pool.restart = Mock()
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         panel.app = self.app
         panel.app = self.app
@@ -403,25 +465,25 @@ class test_ControlPanel(Case):
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             panel.handle('pool_restart', {'reloader': _reload})
             panel.handle('pool_restart', {'reloader': _reload})
 
 
-        current_app.conf.CELERYD_POOL_RESTARTS = True
+        self.app.conf.CELERYD_POOL_RESTARTS = True
         try:
         try:
             panel.handle('pool_restart', {'reloader': _reload})
             panel.handle('pool_restart', {'reloader': _reload})
             self.assertTrue(consumer.controller.pool.restart.called)
             self.assertTrue(consumer.controller.pool.restart.called)
             self.assertFalse(_reload.called)
             self.assertFalse(_reload.called)
             self.assertFalse(_import.called)
             self.assertFalse(_import.called)
         finally:
         finally:
-            current_app.conf.CELERYD_POOL_RESTARTS = False
+            self.app.conf.CELERYD_POOL_RESTARTS = False
 
 
     def test_pool_restart_import_modules(self):
     def test_pool_restart_import_modules(self):
-        consumer = Consumer()
-        consumer.controller = _WC(app=current_app)
+        consumer = Consumer(self.app)
+        consumer.controller = _WC(app=self.app)
         consumer.controller.pool.restart = Mock()
         consumer.controller.pool.restart = Mock()
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         panel.app = self.app
         panel.app = self.app
         _import = consumer.controller.app.loader.import_from_cwd = Mock()
         _import = consumer.controller.app.loader.import_from_cwd = Mock()
         _reload = Mock()
         _reload = Mock()
 
 
-        current_app.conf.CELERYD_POOL_RESTARTS = True
+        self.app.conf.CELERYD_POOL_RESTARTS = True
         try:
         try:
             panel.handle('pool_restart', {'modules': ['foo', 'bar'],
             panel.handle('pool_restart', {'modules': ['foo', 'bar'],
                                           'reloader': _reload})
                                           'reloader': _reload})
@@ -433,18 +495,18 @@ class test_ControlPanel(Case):
                 _import.call_args_list,
                 _import.call_args_list,
             )
             )
         finally:
         finally:
-            current_app.conf.CELERYD_POOL_RESTARTS = False
+            self.app.conf.CELERYD_POOL_RESTARTS = False
 
 
     def test_pool_restart_reload_modules(self):
     def test_pool_restart_reload_modules(self):
-        consumer = Consumer()
-        consumer.controller = _WC(app=current_app)
+        consumer = Consumer(self.app)
+        consumer.controller = _WC(app=self.app)
         consumer.controller.pool.restart = Mock()
         consumer.controller.pool.restart = Mock()
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         panel.app = self.app
         panel.app = self.app
         _import = panel.app.loader.import_from_cwd = Mock()
         _import = panel.app.loader.import_from_cwd = Mock()
         _reload = Mock()
         _reload = Mock()
 
 
-        current_app.conf.CELERYD_POOL_RESTARTS = True
+        self.app.conf.CELERYD_POOL_RESTARTS = True
         try:
         try:
             with patch.dict(sys.modules, {'foo': None}):
             with patch.dict(sys.modules, {'foo': None}):
                 panel.handle('pool_restart', {'modules': ['foo'],
                 panel.handle('pool_restart', {'modules': ['foo'],
@@ -467,4 +529,4 @@ class test_ControlPanel(Case):
                 self.assertTrue(_reload.called)
                 self.assertTrue(_reload.called)
                 self.assertFalse(_import.called)
                 self.assertFalse(_import.called)
         finally:
         finally:
-            current_app.conf.CELERYD_POOL_RESTARTS = False
+            self.app.conf.CELERYD_POOL_RESTARTS = False

+ 15 - 17
celery/tests/worker/test_request.py

@@ -14,9 +14,7 @@ from kombu.utils.encoding import from_utf8, default_encode
 from mock import Mock, patch
 from mock import Mock, patch
 from nose import SkipTest
 from nose import SkipTest
 
 
-from celery import current_app
 from celery import states
 from celery import states
-from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import (
 from celery.exceptions import (
@@ -79,9 +77,9 @@ class test_mro_lookup(Case):
         self.assertIsNone(mro_lookup(D, 'x'))
         self.assertIsNone(mro_lookup(D, 'x'))
 
 
 
 
-def jail(task_id, name, args, kwargs):
+def jail(app, task_id, name, args, kwargs):
     request = {'id': task_id}
     request = {'id': task_id}
-    task = current_app.tasks[name]
+    task = app.tasks[name]
     task.__trace__ = None  # rebuild
     task.__trace__ = None  # rebuild
     return trace_task(
     return trace_task(
         task, task_id, args, kwargs, request=request, eager=False,
         task, task_id, args, kwargs, request=request, eager=False,
@@ -120,9 +118,9 @@ def mytask_raising(i):
     raise KeyError(i)
     raise KeyError(i)
 
 
 
 
-class test_default_encode(Case):
+class test_default_encode(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         if sys.version_info >= (3, 0):
         if sys.version_info >= (3, 0):
             raise SkipTest('py3k: not relevant')
             raise SkipTest('py3k: not relevant')
 
 
@@ -146,7 +144,7 @@ class test_default_encode(Case):
             sys.getfilesystemencoding = gfe
             sys.getfilesystemencoding = gfe
 
 
 
 
-class test_RetryTaskError(Case):
+class test_RetryTaskError(AppCase):
 
 
     def test_retry_task_error(self):
     def test_retry_task_error(self):
         try:
         try:
@@ -156,7 +154,7 @@ class test_RetryTaskError(Case):
             self.assertEqual(ret.exc, exc)
             self.assertEqual(ret.exc, exc)
 
 
 
 
-class test_trace_task(Case):
+class test_trace_task(AppCase):
 
 
     @patch('celery.task.trace._logger')
     @patch('celery.task.trace._logger')
     def test_process_cleanup_fails(self, _logger):
     def test_process_cleanup_fails(self, _logger):
@@ -165,7 +163,7 @@ class test_trace_task(Case):
         mytask.backend.process_cleanup = Mock(side_effect=KeyError())
         mytask.backend.process_cleanup = Mock(side_effect=KeyError())
         try:
         try:
             tid = uuid()
             tid = uuid()
-            ret = jail(tid, mytask.name, [2], {})
+            ret = jail(self.app, tid, mytask.name, [2], {})
             self.assertEqual(ret, 4)
             self.assertEqual(ret, 4)
             mytask.backend.store_result.assert_called_with(tid, 4,
             mytask.backend.store_result.assert_called_with(tid, 4,
                                                            states.SUCCESS)
                                                            states.SUCCESS)
@@ -180,12 +178,12 @@ class test_trace_task(Case):
         mytask.backend.process_cleanup = Mock(side_effect=SystemExit())
         mytask.backend.process_cleanup = Mock(side_effect=SystemExit())
         try:
         try:
             with self.assertRaises(SystemExit):
             with self.assertRaises(SystemExit):
-                jail(uuid(), mytask.name, [2], {})
+                jail(self.app, uuid(), mytask.name, [2], {})
         finally:
         finally:
             mytask.backend = backend
             mytask.backend = backend
 
 
     def test_execute_jail_success(self):
     def test_execute_jail_success(self):
-        ret = jail(uuid(), mytask.name, [2], {})
+        ret = jail(self.app, uuid(), mytask.name, [2], {})
         self.assertEqual(ret, 4)
         self.assertEqual(ret, 4)
 
 
     def test_marked_as_started(self):
     def test_marked_as_started(self):
@@ -202,12 +200,12 @@ class test_trace_task(Case):
 
 
         try:
         try:
             tid = uuid()
             tid = uuid()
-            jail(tid, mytask.name, [2], {})
+            jail(self.app, tid, mytask.name, [2], {})
             self.assertIn(tid, Backend._started)
             self.assertIn(tid, Backend._started)
 
 
             mytask.ignore_result = True
             mytask.ignore_result = True
             tid = uuid()
             tid = uuid()
-            jail(tid, mytask.name, [2], {})
+            jail(self.app, tid, mytask.name, [2], {})
             self.assertNotIn(tid, Backend._started)
             self.assertNotIn(tid, Backend._started)
         finally:
         finally:
             mytask.backend = prev
             mytask.backend = prev
@@ -215,14 +213,14 @@ class test_trace_task(Case):
             mytask.ignore_result = False
             mytask.ignore_result = False
 
 
     def test_execute_jail_failure(self):
     def test_execute_jail_failure(self):
-        ret = jail(uuid(), mytask_raising.name,
+        ret = jail(self.app, uuid(), mytask_raising.name,
                    [4], {})
                    [4], {})
         self.assertIsInstance(ret, ExceptionInfo)
         self.assertIsInstance(ret, ExceptionInfo)
         self.assertTupleEqual(ret.exception.args, (4, ))
         self.assertTupleEqual(ret.exception.args, (4, ))
 
 
     def test_execute_ignore_result(self):
     def test_execute_ignore_result(self):
         task_id = uuid()
         task_id = uuid()
-        ret = jail(task_id, MyTaskIgnoreResult.name, [4], {})
+        ret = jail(self.app, task_id, MyTaskIgnoreResult.name, [4], {})
         self.assertEqual(ret, 256)
         self.assertEqual(ret, 256)
         self.assertFalse(AsyncResult(task_id).ready())
         self.assertFalse(AsyncResult(task_id).ready())
 
 
@@ -355,7 +353,7 @@ class test_TaskRequest(AppCase):
             mytask.ignore_result = False
             mytask.ignore_result = False
 
 
     def test_send_email(self):
     def test_send_email(self):
-        app = app_or_default()
+        app = self.app
         old_mail_admins = app.mail_admins
         old_mail_admins = app.mail_admins
         old_enable_mails = mytask.send_error_emails
         old_enable_mails = mytask.send_error_emails
         mail_sent = [False]
         mail_sent = [False]
@@ -814,7 +812,7 @@ class test_TaskRequest(AppCase):
 
 
     @patch('celery.worker.job.logger')
     @patch('celery.worker.job.logger')
     def _test_on_failure(self, exception, logger):
     def _test_on_failure(self, exception, logger):
-        app = app_or_default()
+        app = self.app
         tid = uuid()
         tid = uuid()
         tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
         tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
         try:
         try:

+ 6 - 7
celery/tests/worker/test_worker.py

@@ -14,7 +14,6 @@ from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from kombu.transport.base import Message
 from mock import call, Mock, patch
 from mock import call, Mock, patch
 
 
-from celery import current_app
 from celery.app.defaults import DEFAULTS
 from celery.app.defaults import DEFAULTS
 from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
@@ -233,13 +232,13 @@ class test_QoS(Case):
         qos.set(qos.prev)
         qos.set(qos.prev)
 
 
 
 
-class test_Consumer(Case):
+class test_Consumer(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         self.buffer = FastQueue()
         self.buffer = FastQueue()
         self.timer = Timer()
         self.timer = Timer()
 
 
-    def tearDown(self):
+    def teardown(self):
         self.timer.stop()
         self.timer.stop()
 
 
     def test_info(self):
     def test_info(self):
@@ -432,7 +431,7 @@ class test_Consumer(Case):
 
 
     def test_loop_ignores_socket_timeout(self):
     def test_loop_ignores_socket_timeout(self):
 
 
-        class Connection(current_app.connection().__class__):
+        class Connection(self.app.connection().__class__):
             obj = None
             obj = None
 
 
             def drain_events(self, **kwargs):
             def drain_events(self, **kwargs):
@@ -448,7 +447,7 @@ class test_Consumer(Case):
 
 
     def test_loop_when_socket_error(self):
     def test_loop_when_socket_error(self):
 
 
-        class Connection(current_app.connection().__class__):
+        class Connection(self.app.connection().__class__):
             obj = None
             obj = None
 
 
             def drain_events(self, **kwargs):
             def drain_events(self, **kwargs):
@@ -470,7 +469,7 @@ class test_Consumer(Case):
 
 
     def test_loop(self):
     def test_loop(self):
 
 
-        class Connection(current_app.connection().__class__):
+        class Connection(self.app.connection().__class__):
             obj = None
             obj = None
 
 
             def drain_events(self, **kwargs):
             def drain_events(self, **kwargs):