Prechádzať zdrojové kódy

100% coverage for celery.worker.pidbox

Ask Solem 12 rokov pred
rodič
commit
14f76ec2a5

+ 44 - 41
celery/tests/app/test_beat.py

@@ -7,15 +7,13 @@ from mock import Mock, call, patch
 from nose import SkipTest
 from pickle import dumps, loads
 
-from celery import current_app
 from celery import beat
 from celery import task
 from celery.five import keys, string_t
 from celery.result import AsyncResult
 from celery.schedules import schedule
-from celery.task.base import Task
 from celery.utils import uuid
-from celery.tests.utils import Case, patch_settings
+from celery.tests.utils import AppCase, patch_settings
 
 
 class Object(object):
@@ -47,7 +45,7 @@ class MockService(object):
         self.stopped = True
 
 
-class test_ScheduleEntry(Case):
+class test_ScheduleEntry(AppCase):
     Entry = beat.ScheduleEntry
 
     def create_entry(self, **kwargs):
@@ -113,7 +111,7 @@ class mScheduler(beat.Scheduler):
                           'args': args,
                           'kwargs': kwargs,
                           'options': options})
-        return AsyncResult(uuid())
+        return AsyncResult(uuid(), app=self.app)
 
 
 class mSchedulerSchedulingError(mScheduler):
@@ -144,17 +142,17 @@ always_due = mocked_schedule(True, 1)
 always_pending = mocked_schedule(False, 1)
 
 
-class test_Scheduler(Case):
+class test_Scheduler(AppCase):
 
     def test_custom_schedule_dict(self):
         custom = {'foo': 'bar'}
-        scheduler = mScheduler(schedule=custom, lazy=True)
+        scheduler = mScheduler(app=self.app, schedule=custom, lazy=True)
         self.assertIs(scheduler.data, custom)
 
     def test_apply_async_uses_registered_task_instances(self):
         through_task = [False]
 
-        class MockTask(Task):
+        class MockTask(self.app.Task):
 
             @classmethod
             def apply_async(cls, *args, **kwargs):
@@ -162,7 +160,7 @@ class test_Scheduler(Case):
 
         assert MockTask.name in MockTask._get_app().tasks
 
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))
         self.assertTrue(through_task[0])
 
@@ -173,7 +171,7 @@ class test_Scheduler(Case):
             pass
         not_sync.apply_async = Mock()
 
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         s._do_sync = Mock()
         s.should_sync = Mock()
         s.should_sync.return_value = True
@@ -187,16 +185,16 @@ class test_Scheduler(Case):
 
     @patch('celery.app.base.Celery.send_task')
     def test_send_task(self, send_task):
-        b = beat.Scheduler()
+        b = beat.Scheduler(app=self.app)
         b.send_task('tasks.add', countdown=10)
         send_task.assert_called_with('tasks.add', countdown=10)
 
     def test_info(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         self.assertIsInstance(scheduler.info, string_t)
 
     def test_maybe_entry(self):
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         entry = s.Entry(name='add every', task='tasks.add')
         self.assertIs(s._maybe_entry(entry.name, entry), entry)
         self.assertTrue(s._maybe_entry('add every', {
@@ -204,13 +202,13 @@ class test_Scheduler(Case):
         }))
 
     def test_set_schedule(self):
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         s.schedule = {'foo': 'bar'}
         self.assertEqual(s.data, {'foo': 'bar'})
 
     @patch('kombu.connection.Connection.ensure_connection')
     def test_ensure_connection_error_handler(self, ensure):
-        s = mScheduler()
+        s = mScheduler(app=self.app)
         self.assertTrue(s._ensure_connected())
         self.assertTrue(ensure.called)
         callback = ensure.call_args[0][0]
@@ -218,29 +216,32 @@ class test_Scheduler(Case):
         callback(KeyError(), 5)
 
     def test_install_default_entries(self):
-        with patch_settings(CELERY_TASK_RESULT_EXPIRES=None,
+        with patch_settings(self.app,
+                            CELERY_TASK_RESULT_EXPIRES=None,
                             CELERYBEAT_SCHEDULE={}):
-            s = mScheduler()
+            s = mScheduler(app=self.app)
             s.install_default_entries({})
             self.assertNotIn('celery.backend_cleanup', s.data)
-        current_app.backend.supports_autoexpire = False
-        with patch_settings(CELERY_TASK_RESULT_EXPIRES=30,
+        self.app.backend.supports_autoexpire = False
+        with patch_settings(self.app,
+                            CELERY_TASK_RESULT_EXPIRES=30,
                             CELERYBEAT_SCHEDULE={}):
-            s = mScheduler()
+            s = mScheduler(app=self.app)
             s.install_default_entries({})
             self.assertIn('celery.backend_cleanup', s.data)
-        current_app.backend.supports_autoexpire = True
+        self.app.backend.supports_autoexpire = True
         try:
-            with patch_settings(CELERY_TASK_RESULT_EXPIRES=31,
+            with patch_settings(self.app,
+                                CELERY_TASK_RESULT_EXPIRES=31,
                                 CELERYBEAT_SCHEDULE={}):
-                s = mScheduler()
+                s = mScheduler(app=self.app)
                 s.install_default_entries({})
                 self.assertNotIn('celery.backend_cleanup', s.data)
         finally:
-            current_app.backend.supports_autoexpire = False
+            self.app.backend.supports_autoexpire = False
 
     def test_due_tick(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_due_tick',
                       schedule=always_due,
                       args=(1, 2),
@@ -249,33 +250,33 @@ class test_Scheduler(Case):
 
     @patch('celery.beat.error')
     def test_due_tick_SchedulingError(self, error):
-        scheduler = mSchedulerSchedulingError()
+        scheduler = mSchedulerSchedulingError(app=self.app)
         scheduler.add(name='test_due_tick_SchedulingError',
                       schedule=always_due)
         self.assertEqual(scheduler.tick(), 1)
         self.assertTrue(error.called)
 
     def test_due_tick_RuntimeError(self):
-        scheduler = mSchedulerRuntimeError()
+        scheduler = mSchedulerRuntimeError(app=self.app)
         scheduler.add(name='test_due_tick_RuntimeError',
                       schedule=always_due)
         self.assertEqual(scheduler.tick(), scheduler.max_interval)
 
     def test_pending_tick(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_pending_tick',
                       schedule=always_pending)
         self.assertEqual(scheduler.tick(), 1)
 
     def test_honors_max_interval(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         maxi = scheduler.max_interval
         scheduler.add(name='test_honors_max_interval',
                       schedule=mocked_schedule(False, maxi * 4))
         self.assertEqual(scheduler.tick(), maxi)
 
     def test_ticks(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         nums = [600, 300, 650, 120, 250, 36]
         s = dict(('test_ticks%s' % i,
                  {'schedule': mocked_schedule(False, j)})
@@ -284,20 +285,20 @@ class test_Scheduler(Case):
         self.assertEqual(scheduler.tick(), min(nums))
 
     def test_schedule_no_remain(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_schedule_no_remain',
                       schedule=mocked_schedule(False, None))
         self.assertEqual(scheduler.tick(), scheduler.max_interval)
 
     def test_interface(self):
-        scheduler = mScheduler()
+        scheduler = mScheduler(app=self.app)
         scheduler.sync()
         scheduler.setup_schedule()
         scheduler.close()
 
     def test_merge_inplace(self):
-        a = mScheduler()
-        b = mScheduler()
+        a = mScheduler(app=self.app)
+        b = mScheduler(app=self.app)
         a.update_from_dict({'foo': {'schedule': mocked_schedule(True, 10)},
                             'bar': {'schedule': mocked_schedule(True, 20)}})
         b.update_from_dict({'bar': {'schedule': mocked_schedule(True, 40)},
@@ -330,11 +331,12 @@ def create_persistent_scheduler(shelv=None):
     return MockPersistentScheduler, shelv
 
 
-class test_PersistentScheduler(Case):
+class test_PersistentScheduler(AppCase):
 
     @patch('os.remove')
     def test_remove_db(self, remove):
-        s = create_persistent_scheduler()[0](schedule_filename='schedule')
+        s = create_persistent_scheduler()[0](app=self.app,
+                                             schedule_filename='schedule')
         s._remove_db()
         remove.assert_has_calls(
             [call('schedule' + suffix) for suffix in s.known_suffixes]
@@ -348,7 +350,8 @@ class test_PersistentScheduler(Case):
             s._remove_db()
 
     def test_setup_schedule(self):
-        s = create_persistent_scheduler()[0](schedule_filename='schedule')
+        s = create_persistent_scheduler()[0](app=self.app,
+                                             schedule_filename='schedule')
         opens = s.persistence.open = Mock()
         s._remove_db = Mock()
 
@@ -383,14 +386,14 @@ class test_PersistentScheduler(Case):
         self.assertDictEqual(s._store['entries'], s.schedule)
 
 
-class test_Service(Case):
+class test_Service(AppCase):
 
     def get_service(self):
         Scheduler, mock_shelve = create_persistent_scheduler()
-        return beat.Service(scheduler_cls=Scheduler), mock_shelve
+        return beat.Service(app=self.app, scheduler_cls=Scheduler), mock_shelve
 
     def test_pickleable(self):
-        s = beat.Service(scheduler_cls=Mock)
+        s = beat.Service(app=self.app, scheduler_cls=Mock)
         self.assertTrue(loads(dumps(s)))
 
     def test_start(self):
@@ -442,7 +445,7 @@ class test_Service(Case):
         self.assertTrue(s._is_shutdown.isSet())
 
 
-class test_EmbeddedService(Case):
+class test_EmbeddedService(AppCase):
 
     def test_start_stop_process(self):
         try:

+ 46 - 43
celery/tests/app/test_builtins.py

@@ -2,93 +2,97 @@ from __future__ import absolute_import
 
 from mock import Mock, patch
 
-from celery import current_app as app, group, task, chord
+from celery import group, shared_task, chord
 from celery.app import builtins
 from celery.canvas import Signature
 from celery.five import range
 from celery._state import _task_stack
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 
 
-@task()
+@shared_task()
 def add(x, y):
     return x + y
 
 
-@task()
+@shared_task()
 def xsum(x):
     return sum(x)
 
 
-class test_backend_cleanup(Case):
+class test_backend_cleanup(AppCase):
 
     def test_run(self):
-        prev = app.backend
-        app.backend.cleanup = Mock()
-        app.backend.cleanup.__name__ = 'cleanup'
+        prev = self.app.backend
+        self.app.backend.cleanup = Mock()
+        self.app.backend.cleanup.__name__ = 'cleanup'
         try:
-            cleanup_task = builtins.add_backend_cleanup_task(app)
+            cleanup_task = builtins.add_backend_cleanup_task(self.app)
             cleanup_task()
-            self.assertTrue(app.backend.cleanup.called)
+            self.assertTrue(self.app.backend.cleanup.called)
         finally:
-            app.backend = prev
+            self.app.backend = prev
 
 
-class test_map(Case):
+class test_map(AppCase):
 
     def test_run(self):
 
-        @app.task()
+        @self.app.task()
         def map_mul(x):
             return x[0] * x[1]
 
-        res = app.tasks['celery.map'](map_mul, [(2, 2), (4, 4), (8, 8)])
+        res = self.app.tasks['celery.map'](
+            map_mul, [(2, 2), (4, 4), (8, 8)],
+        )
         self.assertEqual(res, [4, 16, 64])
 
 
-class test_starmap(Case):
+class test_starmap(AppCase):
 
     def test_run(self):
 
-        @app.task()
+        @self.app.task()
         def smap_mul(x, y):
             return x * y
 
-        res = app.tasks['celery.starmap'](smap_mul, [(2, 2), (4, 4), (8, 8)])
+        res = self.app.tasks['celery.starmap'](
+            smap_mul, [(2, 2), (4, 4), (8, 8)],
+        )
         self.assertEqual(res, [4, 16, 64])
 
 
-class test_chunks(Case):
+class test_chunks(AppCase):
 
     @patch('celery.canvas.chunks.apply_chunks')
     def test_run(self, apply_chunks):
 
-        @app.task()
+        @self.app.task()
         def chunks_mul(l):
             return l
 
-        app.tasks['celery.chunks'](
+        self.app.tasks['celery.chunks'](
             chunks_mul, [(2, 2), (4, 4), (8, 8)], 1,
         )
         self.assertTrue(apply_chunks.called)
 
 
-class test_group(Case):
+class test_group(AppCase):
 
-    def setUp(self):
-        self.prev = app.tasks.get('celery.group')
-        self.task = builtins.add_group_task(app)()
+    def setup(self):
+        self.prev = self.app.tasks.get('celery.group')
+        self.task = builtins.add_group_task(self.app)()
 
-    def tearDown(self):
-        app.tasks['celery.group'] = self.prev
+    def teardown(self):
+        self.app.tasks['celery.group'] = self.prev
 
     def test_apply_async_eager(self):
         self.task.apply = Mock()
-        app.conf.CELERY_ALWAYS_EAGER = True
+        self.app.conf.CELERY_ALWAYS_EAGER = True
         try:
             self.task.apply_async()
         finally:
-            app.conf.CELERY_ALWAYS_EAGER = False
+            self.app.conf.CELERY_ALWAYS_EAGER = False
         self.assertTrue(self.task.apply.called)
 
     def test_apply(self):
@@ -125,14 +129,14 @@ class test_group(Case):
             _task_stack.pop()
 
 
-class test_chain(Case):
+class test_chain(AppCase):
 
-    def setUp(self):
-        self.prev = app.tasks.get('celery.chain')
-        self.task = builtins.add_chain_task(app)()
+    def setup(self):
+        self.prev = self.app.tasks.get('celery.chain')
+        self.task = builtins.add_chain_task(self.app)()
 
-    def tearDown(self):
-        app.tasks['celery.chain'] = self.prev
+    def teardown(self):
+        self.app.tasks['celery.chain'] = self.prev
 
     def test_apply_async(self):
         c = add.s(2, 2) | add.s(4) | add.s(8)
@@ -185,14 +189,14 @@ class test_chain(Case):
             self.assertListEqual(task.options['link_error'], [s('error')])
 
 
-class test_chord(Case):
+class test_chord(AppCase):
 
-    def setUp(self):
-        self.prev = app.tasks.get('celery.chord')
-        self.task = builtins.add_chord_task(app)()
+    def setup(self):
+        self.prev = self.app.tasks.get('celery.chord')
+        self.task = builtins.add_chord_task(self.app)()
 
-    def tearDown(self):
-        app.tasks['celery.chord'] = self.prev
+    def teardown(self):
+        self.app.tasks['celery.chord'] = self.prev
 
     def test_apply_async(self):
         x = chord([add.s(i, i) for i in range(10)], body=xsum.s())
@@ -213,11 +217,10 @@ class test_chord(Case):
         self.assertEqual(body.options['chord'], 'some_chord_id')
 
     def test_apply_eager(self):
-        app.conf.CELERY_ALWAYS_EAGER = True
+        self.app.conf.CELERY_ALWAYS_EAGER = True
         try:
             x = chord([add.s(i, i) for i in range(10)], body=xsum.s())
             r = x.apply_async()
             self.assertEqual(r.get(), 90)
-
         finally:
-            app.conf.CELERY_ALWAYS_EAGER = False
+            self.app.conf.CELERY_ALWAYS_EAGER = False

+ 5 - 8
celery/tests/app/test_loaders.py

@@ -7,7 +7,6 @@ import warnings
 from mock import Mock, patch
 
 from celery import loaders
-from celery.app import app_or_default
 from celery.exceptions import (
     NotConfigured,
     CPendingDeprecationWarning,
@@ -47,7 +46,7 @@ class test_loaders(AppCase):
             self.assertIs(loaders.load_settings(), self.app.conf)
 
 
-class test_LoaderBase(Case):
+class test_LoaderBase(AppCase):
     message_options = {'subject': 'Subject',
                        'body': 'Body',
                        'sender': 'x@x.com',
@@ -58,9 +57,8 @@ class test_LoaderBase(Case):
                       'password': 'qwerty',
                       'timeout': 3}
 
-    def setUp(self):
-        self.loader = DummyLoader()
-        self.app = app_or_default()
+    def setup(self):
+        self.loader = DummyLoader(app=self.app)
 
     def test_handlers_pass(self):
         self.loader.on_task_init('foo.task', 'feedface-cafebabe')
@@ -222,10 +220,9 @@ class test_DefaultLoader(Case):
         self.assertTrue(context_executed[0])
 
 
-class test_AppLoader(Case):
+class test_AppLoader(AppCase):
 
-    def setUp(self):
-        self.app = app_or_default()
+    def setup(self):
         self.loader = AppLoader(app=self.app)
 
     def test_on_worker_init(self):

+ 14 - 13
celery/tests/app/test_log.py

@@ -7,7 +7,6 @@ from tempfile import mktemp
 from mock import patch, Mock
 from nose import SkipTest
 
-from celery import current_app
 from celery import signals
 from celery.app.log import Logging, TaskFormatter
 from celery.utils.log import LoggingProxy
@@ -22,7 +21,6 @@ from celery.tests.utils import (
     AppCase, Case, override_stdouts, wrap_logger, get_handlers,
 )
 
-log = current_app.log
 
 
 class test_TaskFormatter(Case):
@@ -110,7 +108,7 @@ class test_ColorFormatter(Case):
 class test_default_logger(AppCase):
 
     def setup(self):
-        self.setup_logger = log.setup_logger
+        self.setup_logger = self.app.log.setup_logger
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         signals.setup_logging.receivers[:] = []
         Logging._setup = False
@@ -124,12 +122,12 @@ class test_default_logger(AppCase):
         self.assertIs(logger.parent, logging.root)
 
     def test_setup_logging_subsystem_misc(self):
-        log.setup_logging_subsystem(loglevel=None)
+        self.app.log.setup_logging_subsystem(loglevel=None)
 
     def test_setup_logging_subsystem_misc2(self):
         self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
         try:
-            log.setup_logging_subsystem()
+            self.app.log.setup_logging_subsystem()
         finally:
             self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
 
@@ -142,14 +140,14 @@ class test_default_logger(AppCase):
         logger.handlers[:] = []
 
     def test_setup_logging_subsystem_colorize(self):
-        log.setup_logging_subsystem(colorize=None)
-        log.setup_logging_subsystem(colorize=True)
+        self.app.log.setup_logging_subsystem(colorize=None)
+        self.app.log.setup_logging_subsystem(colorize=True)
 
     def test_setup_logging_subsystem_no_mputil(self):
         from celery.utils import log as logtools
         mputil, logtools.mputil = logtools.mputil, None
         try:
-            log.setup_logging_subsystem()
+            self.app.log.setup_logging_subsystem()
         finally:
             logtools.mputil = mputil
 
@@ -203,11 +201,14 @@ class test_default_logger(AppCase):
                                    root=False)
         try:
             with wrap_logger(logger) as sio:
-                log.redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, loglevel=logging.ERROR,
+                )
                 logger.error('foo')
                 self.assertIn('foo', sio.getvalue())
-                log.redirect_stdouts_to_logger(logger, stdout=False,
-                                               stderr=False)
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, stdout=False, stderr=False,
+                )
         finally:
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
@@ -250,7 +251,7 @@ class test_task_logger(test_default_logger):
         logging.root.manager.loggerDict.pop(logger.name, None)
         self.uid = uuid()
 
-        @current_app.task
+        @self.app.task
         def test_task():
             pass
         self.get_logger().handlers = []
@@ -263,7 +264,7 @@ class test_task_logger(test_default_logger):
         _task_stack.pop()
 
     def setup_logger(self, *args, **kwargs):
-        return log.setup_task_loggers(*args, **kwargs)
+        return self.app.log.setup_task_loggers(*args, **kwargs)
 
     def get_logger(self, *args, **kwargs):
         return get_task_logger("test_task_logger")

+ 101 - 100
celery/tests/app/test_routes.py

@@ -1,19 +1,18 @@
 from __future__ import absolute_import
 
-from functools import wraps
+from contextlib import contextmanager
 
 from kombu import Exchange
 from kombu.utils.functional import maybe_promise
 
-from celery import current_app
 from celery.app import routes
 from celery.exceptions import QueueNotFound
 from celery.task import task
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 
 
-def Router(*args, **kwargs):
-    return routes.Router(*args, app=current_app, **kwargs)
+def Router(app, *args, **kwargs):
+    return routes.Router(*args, app=app, **kwargs)
 
 
 @task()
@@ -21,72 +20,70 @@ def mytask():
     pass
 
 
-def E(queues):
+def E(app, queues):
     def expand(answer):
-        return Router([], queues).expand_destination(answer)
+        return Router(app, [], queues).expand_destination(answer)
     return expand
 
 
-def with_queues(**queues):
-
-    def patch_fun(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            app = current_app
-            prev_queues = app.conf.CELERY_QUEUES
-            prev_Queues = app.amqp.queues
-            app.conf.CELERY_QUEUES = queues
-            app.amqp.queues = app.amqp.Queues(queues)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                app.conf.CELERY_QUEUES = prev_queues
-                app.amqp.queues = prev_Queues
-        return __inner
-    return patch_fun
-
-
-a_queue = {'exchange': 'fooexchange',
-           'exchange_type': 'fanout',
-           'routing_key': 'xuzzy'}
-b_queue = {'exchange': 'barexchange',
-           'exchange_type': 'topic',
-           'routing_key': 'b.b.#'}
-d_queue = {'exchange': current_app.conf.CELERY_DEFAULT_EXCHANGE,
-           'exchange_type': current_app.conf.CELERY_DEFAULT_EXCHANGE_TYPE,
-           'routing_key': current_app.conf.CELERY_DEFAULT_ROUTING_KEY}
-
-
-class RouteCase(Case):
-    pass
+@contextmanager
+def _queues(app, **queues):
+    prev_queues = app.conf.CELERY_QUEUES
+    prev_Queues = app.amqp.queues
+    app.conf.CELERY_QUEUES = queues
+    app.amqp.queues = app.amqp.Queues(queues)
+    try:
+        yield
+    finally:
+        app.conf.CELERY_QUEUES = prev_queues
+        app.amqp.queues = prev_Queues
+
+
+class RouteCase(AppCase):
+
+    def setup(self):
+        self.a_queue = {
+            'exchange': 'fooexchange',
+            'exchange_type': 'fanout',
+            'routing_key': 'xuzzy',
+        }
+        self.b_queue = {
+            'exchange': 'barexchange',
+            'exchange_type': 'topic',
+            'routing_key': 'b.b.#',
+        }
+        self.d_queue = {
+            'exchange': self.app.conf.CELERY_DEFAULT_EXCHANGE,
+            'exchange_type': self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE,
+            'routing_key': self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
+        }
 
 
 class test_MapRoute(RouteCase):
 
-    @with_queues(foo=a_queue, bar=b_queue)
     def test_route_for_task_expanded_route(self):
-        expand = E(current_app.amqp.queues)
-        route = routes.MapRoute({mytask.name: {'queue': 'foo'}})
-        self.assertEqual(
-            expand(route.route_for_task(mytask.name))['queue'].name,
-            'foo',
-        )
-        self.assertIsNone(route.route_for_task('celery.awesome'))
-
-    @with_queues(foo=a_queue, bar=b_queue)
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            expand = E(self.app, self.app.amqp.queues)
+            route = routes.MapRoute({mytask.name: {'queue': 'foo'}})
+            self.assertEqual(
+                expand(route.route_for_task(mytask.name))['queue'].name,
+                'foo',
+            )
+            self.assertIsNone(route.route_for_task('celery.awesome'))
+
     def test_route_for_task(self):
-        expand = E(current_app.amqp.queues)
-        route = routes.MapRoute({mytask.name: b_queue})
-        self.assertDictContainsSubset(
-            b_queue,
-            expand(route.route_for_task(mytask.name)),
-        )
-        self.assertIsNone(route.route_for_task('celery.awesome'))
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            expand = E(self.app, self.app.amqp.queues)
+            route = routes.MapRoute({mytask.name: self.b_queue})
+            self.assertDictContainsSubset(
+                self.b_queue,
+                expand(route.route_for_task(mytask.name)),
+            )
+            self.assertIsNone(route.route_for_task('celery.awesome'))
 
     def test_expand_route_not_found(self):
-        expand = E(current_app.amqp.Queues(
-            current_app.conf.CELERY_QUEUES, False))
+        expand = E(self.app, self.app.amqp.Queues(
+                   self.app.conf.CELERY_QUEUES, False))
         route = routes.MapRoute({'a': {'queue': 'x'}})
         with self.assertRaises(QueueNotFound):
             expand(route.route_for_task('a'))
@@ -95,55 +92,59 @@ class test_MapRoute(RouteCase):
 class test_lookup_route(RouteCase):
 
     def test_init_queues(self):
-        router = Router(queues=None)
+        router = Router(self.app, queues=None)
         self.assertDictEqual(router.queues, {})
 
-    @with_queues(foo=a_queue, bar=b_queue)
     def test_lookup_takes_first(self):
-        R = routes.prepare(({mytask.name: {'queue': 'bar'}},
-                            {mytask.name: {'queue': 'foo'}}))
-        router = Router(R, current_app.amqp.queues)
-        self.assertEqual(router.route({}, mytask.name,
-                         args=[1, 2], kwargs={})['queue'].name, 'bar')
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            R = routes.prepare(({mytask.name: {'queue': 'bar'}},
+                                {mytask.name: {'queue': 'foo'}}))
+            router = Router(self.app, R, self.app.amqp.queues)
+            self.assertEqual(router.route({}, mytask.name,
+                             args=[1, 2], kwargs={})['queue'].name, 'bar')
 
-    @with_queues()
     def test_expands_queue_in_options(self):
-        R = routes.prepare(())
-        router = Router(R, current_app.amqp.queues, create_missing=True)
-        # apply_async forwards all arguments, even exchange=None etc,
-        # so need to make sure it's merged correctly.
-        route = router.route({'queue': 'testq',
-                              'exchange': None,
-                              'routing_key': None,
-                              'immediate': False},
-                             mytask.name,
-                             args=[1, 2], kwargs={})
-        self.assertEqual(route['queue'].name, 'testq')
-        self.assertEqual(route['queue'].exchange, Exchange('testq'))
-        self.assertEqual(route['queue'].routing_key, 'testq')
-        self.assertEqual(route['immediate'], False)
-
-    @with_queues(foo=a_queue, bar=b_queue)
+        with _queues(self.app):
+            R = routes.prepare(())
+            router = Router(self.app, R, self.app.amqp.queues, create_missing=True)
+            # apply_async forwards all arguments, even exchange=None etc,
+            # so need to make sure it's merged correctly.
+            route = router.route(
+                {'queue': 'testq',
+                 'exchange': None,
+                 'routing_key': None,
+                 'immediate': False},
+                mytask.name,
+                args=[1, 2], kwargs={},
+            )
+            self.assertEqual(route['queue'].name, 'testq')
+            self.assertEqual(route['queue'].exchange, Exchange('testq'))
+            self.assertEqual(route['queue'].routing_key, 'testq')
+            self.assertEqual(route['immediate'], False)
+
     def test_expand_destination_string(self):
-        x = Router({}, current_app.amqp.queues)
-        dest = x.expand_destination('foo')
-        self.assertEqual(dest['queue'].name, 'foo')
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue):
+            x = Router(self.app, {}, self.app.amqp.queues)
+            dest = x.expand_destination('foo')
+            self.assertEqual(dest['queue'].name, 'foo')
 
-    @with_queues(foo=a_queue, bar=b_queue, **{
-        current_app.conf.CELERY_DEFAULT_QUEUE: d_queue})
     def test_lookup_paths_traversed(self):
-        R = routes.prepare(({'celery.xaza': {'queue': 'bar'}},
-                            {mytask.name: {'queue': 'foo'}}))
-        router = Router(R, current_app.amqp.queues)
-        self.assertEqual(router.route({}, mytask.name,
-                         args=[1, 2], kwargs={})['queue'].name, 'foo')
-        self.assertEqual(
-            router.route({}, 'celery.poza')['queue'].name,
-            current_app.conf.CELERY_DEFAULT_QUEUE,
-        )
-
-
-class test_prepare(Case):
+        with _queues(self.app, foo=self.a_queue, bar=self.b_queue, **{
+                self.app.conf.CELERY_DEFAULT_QUEUE: self.d_queue}):
+            R = routes.prepare((
+                {'celery.xaza': {'queue': 'bar'}},
+                {mytask.name: {'queue': 'foo'}}
+            ))
+            router = Router(self.app, R, self.app.amqp.queues)
+            self.assertEqual(router.route({}, mytask.name,
+                             args=[1, 2], kwargs={})['queue'].name, 'foo')
+            self.assertEqual(
+                router.route({}, 'celery.poza')['queue'].name,
+                self.app.conf.CELERY_DEFAULT_QUEUE,
+            )
+
+
+class test_prepare(AppCase):
 
     def test_prepare(self):
         from celery.datastructures import LRUCache

+ 1 - 2
celery/tests/utilities/test_platforms.py

@@ -7,7 +7,6 @@ import signal
 
 from mock import Mock, patch
 
-from celery import current_app
 from celery import platforms
 from celery.five import open_fqdn
 from celery.platforms import (
@@ -97,7 +96,7 @@ class test_Signals(Case):
         signals['INT'] = lambda *a: a
 
 
-if not current_app.IS_WINDOWS:
+if not platforms.IS_WINDOWS:
 
     class test_get_fdmax(Case):
 

+ 95 - 33
celery/tests/worker/test_control.py

@@ -9,7 +9,6 @@ from datetime import datetime, timedelta
 from kombu import pidbox
 from mock import Mock, patch, call
 
-from celery import current_app
 from celery.datastructures import AttributeDict
 from celery.task import task
 from celery.utils import uuid
@@ -22,7 +21,8 @@ from celery.five import Queue as FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.state import revoked
 from celery.worker.control import Panel
-from celery.tests.utils import Case
+from celery.worker.pidbox import Pidbox, gPidbox
+from celery.tests.utils import AppCase, Case
 
 hostname = socket.gethostname()
 
@@ -41,11 +41,11 @@ class WorkController(object):
 
 class Consumer(consumer.Consumer):
 
-    def __init__(self):
+    def __init__(self, app):
+        self.app = app
         self.buffer = FastQueue()
         self.handle_task = self.buffer.put
         self.timer = Timer()
-        self.app = current_app
         self.event_dispatcher = Mock()
         self.controller = WorkController()
         self.task_consumer = Mock()
@@ -55,11 +55,73 @@ class Consumer(consumer.Consumer):
         self.task_buckets = defaultdict(lambda: None)
 
 
-class test_ControlPanel(Case):
+class test_Pidbox(AppCase):
 
-    def setUp(self):
-        self.app = current_app
-        self.panel = self.create_panel(consumer=Consumer())
+    def test_shutdown(self):
+        with patch('celery.worker.pidbox.ignore_errors') as eig:
+            parent = Mock()
+            pidbox = Pidbox(parent)
+            pidbox._close_channel = Mock()
+            self.assertIs(pidbox.c, parent)
+            pconsumer = pidbox.consumer = Mock()
+            cancel = pconsumer.cancel
+            pidbox.shutdown(parent)
+            eig.assert_called_with(parent, cancel)
+            pidbox._close_channel.assert_called_with(parent)
+
+
+class test_Pidbox_green(AppCase):
+
+    def test_stop(self):
+        parent = Mock()
+        g = gPidbox(parent)
+        stopped = g._node_stopped = Mock()
+        shutdown = g._node_shutdown = Mock()
+        close_chan = g._close_channel = Mock()
+
+        g.stop(parent)
+        shutdown.set.assert_called_with()
+        stopped.wait.assert_called_with()
+        close_chan.assert_called_with(parent)
+        self.assertIsNone(g._node_stopped)
+        self.assertIsNone(g._node_shutdown)
+
+        close_chan.reset()
+        g.stop(parent)
+        close_chan.assert_called_with(parent)
+
+    def test_resets(self):
+        parent = Mock()
+        g = gPidbox(parent)
+        g._resets = 100
+        g.reset()
+        self.assertEqual(g._resets, 101)
+
+    def test_loop(self):
+        parent = Mock()
+        conn = parent.connect.return_value = self.app.connection()
+        drain = conn.drain_events = Mock()
+        g = gPidbox(parent)
+        parent.connection = Mock()
+        do_reset = g._do_reset = Mock()
+
+        call_count = [0]
+
+        def se(*args, **kwargs):
+            if call_count[0] > 2:
+                g._node_shutdown.set()
+            g.reset()
+            call_count[0] += 1
+        drain.side_effect = se
+        g.loop(parent)
+
+        self.assertEqual(do_reset.call_count, 4)
+
+
+class test_ControlPanel(AppCase):
+
+    def setup(self):
+        self.panel = self.create_panel(consumer=Consumer(self.app))
 
     def create_state(self, **kwargs):
         kwargs.setdefault('app', self.app)
@@ -71,7 +133,7 @@ class test_ControlPanel(Case):
                                              handlers=Panel.data)
 
     def test_enable_events(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         evd = consumer.event_dispatcher
         evd.groups = set()
@@ -81,7 +143,7 @@ class test_ControlPanel(Case):
         self.assertIn('already enabled', panel.handle('enable_events')['ok'])
 
     def test_disable_events(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         evd = consumer.event_dispatcher
         evd.enabled = True
@@ -91,7 +153,7 @@ class test_ControlPanel(Case):
         self.assertIn('already disabled', panel.handle('disable_events')['ok'])
 
     def test_heartbeat(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         consumer.event_dispatcher.enabled = True
         panel.handle('heartbeat')
@@ -122,7 +184,7 @@ class test_ControlPanel(Case):
     def test_active_queues(self):
         import kombu
 
-        x = kombu.Consumer(current_app.connection(),
+        x = kombu.Consumer(self.app.connection(),
                            [kombu.Queue('foo', kombu.Exchange('foo'), 'foo'),
                             kombu.Queue('bar', kombu.Exchange('bar'), 'bar')],
                            auto_declare=False)
@@ -170,7 +232,7 @@ class test_ControlPanel(Case):
             def shrink(self, n=1):
                 self.size -= n
 
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         consumer.pool = MockPool()
         panel = self.create_panel(consumer=consumer)
 
@@ -206,7 +268,7 @@ class test_ControlPanel(Case):
             def consuming_from(self, queue):
                 return queue in self.queues
 
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         consumer.task_consumer = MockConsumer()
         panel = self.create_panel(consumer=consumer)
 
@@ -229,7 +291,7 @@ class test_ControlPanel(Case):
             state.revoked.clear()
 
     def test_dump_schedule(self):
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         self.assertFalse(panel.handle('dump_schedule'))
         r = TaskRequest(mytask.name, 'CAFEBABE', (), {})
@@ -240,7 +302,7 @@ class test_ControlPanel(Case):
 
     def test_dump_reserved(self):
         from celery.worker import state
-        consumer = Consumer()
+        consumer = Consumer(self.app)
         state.reserved_requests.add(
             TaskRequest(mytask.name, uuid(), args=(2, 2), kwargs={}),
         )
@@ -266,16 +328,16 @@ class test_ControlPanel(Case):
 
     def test_rate_limit(self):
 
-        class Consumer(object):
+        class xConsumer(object):
             reset = False
 
             def reset_rate_limits(self):
                 self.reset = True
 
-        consumer = Consumer()
-        panel = self.create_panel(app=current_app, consumer=consumer)
+        consumer = xConsumer()
+        panel = self.create_panel(app=self.app, consumer=consumer)
 
-        task = current_app.tasks[mytask.name]
+        task = self.app.tasks[mytask.name]
         old_rate_limit = task.rate_limit
         try:
             panel.handle('rate_limit', arguments=dict(task_name=task.name,
@@ -383,7 +445,7 @@ class test_ControlPanel(Case):
                 replies.append(data)
 
         panel = _Node(hostname=hostname,
-                      state=self.create_state(consumer=Consumer()),
+                      state=self.create_state(consumer=Consumer(self.app)),
                       handlers=Panel.data,
                       mailbox=self.app.control.mailbox)
         r = panel.dispatch('ping', reply_to={'exchange': 'x',
@@ -392,8 +454,8 @@ class test_ControlPanel(Case):
         self.assertDictEqual(replies[0], {panel.hostname: {'ok': 'pong'}})
 
     def test_pool_restart(self):
-        consumer = Consumer()
-        consumer.controller = _WC(app=current_app)
+        consumer = Consumer(self.app)
+        consumer.controller = _WC(app=self.app)
         consumer.controller.pool.restart = Mock()
         panel = self.create_panel(consumer=consumer)
         panel.app = self.app
@@ -403,25 +465,25 @@ class test_ControlPanel(Case):
         with self.assertRaises(ValueError):
             panel.handle('pool_restart', {'reloader': _reload})
 
-        current_app.conf.CELERYD_POOL_RESTARTS = True
+        self.app.conf.CELERYD_POOL_RESTARTS = True
         try:
             panel.handle('pool_restart', {'reloader': _reload})
             self.assertTrue(consumer.controller.pool.restart.called)
             self.assertFalse(_reload.called)
             self.assertFalse(_import.called)
         finally:
-            current_app.conf.CELERYD_POOL_RESTARTS = False
+            self.app.conf.CELERYD_POOL_RESTARTS = False
 
     def test_pool_restart_import_modules(self):
-        consumer = Consumer()
-        consumer.controller = _WC(app=current_app)
+        consumer = Consumer(self.app)
+        consumer.controller = _WC(app=self.app)
         consumer.controller.pool.restart = Mock()
         panel = self.create_panel(consumer=consumer)
         panel.app = self.app
         _import = consumer.controller.app.loader.import_from_cwd = Mock()
         _reload = Mock()
 
-        current_app.conf.CELERYD_POOL_RESTARTS = True
+        self.app.conf.CELERYD_POOL_RESTARTS = True
         try:
             panel.handle('pool_restart', {'modules': ['foo', 'bar'],
                                           'reloader': _reload})
@@ -433,18 +495,18 @@ class test_ControlPanel(Case):
                 _import.call_args_list,
             )
         finally:
-            current_app.conf.CELERYD_POOL_RESTARTS = False
+            self.app.conf.CELERYD_POOL_RESTARTS = False
 
     def test_pool_restart_reload_modules(self):
-        consumer = Consumer()
-        consumer.controller = _WC(app=current_app)
+        consumer = Consumer(self.app)
+        consumer.controller = _WC(app=self.app)
         consumer.controller.pool.restart = Mock()
         panel = self.create_panel(consumer=consumer)
         panel.app = self.app
         _import = panel.app.loader.import_from_cwd = Mock()
         _reload = Mock()
 
-        current_app.conf.CELERYD_POOL_RESTARTS = True
+        self.app.conf.CELERYD_POOL_RESTARTS = True
         try:
             with patch.dict(sys.modules, {'foo': None}):
                 panel.handle('pool_restart', {'modules': ['foo'],
@@ -467,4 +529,4 @@ class test_ControlPanel(Case):
                 self.assertTrue(_reload.called)
                 self.assertFalse(_import.called)
         finally:
-            current_app.conf.CELERYD_POOL_RESTARTS = False
+            self.app.conf.CELERYD_POOL_RESTARTS = False

+ 15 - 17
celery/tests/worker/test_request.py

@@ -14,9 +14,7 @@ from kombu.utils.encoding import from_utf8, default_encode
 from mock import Mock, patch
 from nose import SkipTest
 
-from celery import current_app
 from celery import states
-from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import (
@@ -79,9 +77,9 @@ class test_mro_lookup(Case):
         self.assertIsNone(mro_lookup(D, 'x'))
 
 
-def jail(task_id, name, args, kwargs):
+def jail(app, task_id, name, args, kwargs):
     request = {'id': task_id}
-    task = current_app.tasks[name]
+    task = app.tasks[name]
     task.__trace__ = None  # rebuild
     return trace_task(
         task, task_id, args, kwargs, request=request, eager=False,
@@ -120,9 +118,9 @@ def mytask_raising(i):
     raise KeyError(i)
 
 
-class test_default_encode(Case):
+class test_default_encode(AppCase):
 
-    def setUp(self):
+    def setup(self):
         if sys.version_info >= (3, 0):
             raise SkipTest('py3k: not relevant')
 
@@ -146,7 +144,7 @@ class test_default_encode(Case):
             sys.getfilesystemencoding = gfe
 
 
-class test_RetryTaskError(Case):
+class test_RetryTaskError(AppCase):
 
     def test_retry_task_error(self):
         try:
@@ -156,7 +154,7 @@ class test_RetryTaskError(Case):
             self.assertEqual(ret.exc, exc)
 
 
-class test_trace_task(Case):
+class test_trace_task(AppCase):
 
     @patch('celery.task.trace._logger')
     def test_process_cleanup_fails(self, _logger):
@@ -165,7 +163,7 @@ class test_trace_task(Case):
         mytask.backend.process_cleanup = Mock(side_effect=KeyError())
         try:
             tid = uuid()
-            ret = jail(tid, mytask.name, [2], {})
+            ret = jail(self.app, tid, mytask.name, [2], {})
             self.assertEqual(ret, 4)
             mytask.backend.store_result.assert_called_with(tid, 4,
                                                            states.SUCCESS)
@@ -180,12 +178,12 @@ class test_trace_task(Case):
         mytask.backend.process_cleanup = Mock(side_effect=SystemExit())
         try:
             with self.assertRaises(SystemExit):
-                jail(uuid(), mytask.name, [2], {})
+                jail(self.app, uuid(), mytask.name, [2], {})
         finally:
             mytask.backend = backend
 
     def test_execute_jail_success(self):
-        ret = jail(uuid(), mytask.name, [2], {})
+        ret = jail(self.app, uuid(), mytask.name, [2], {})
         self.assertEqual(ret, 4)
 
     def test_marked_as_started(self):
@@ -202,12 +200,12 @@ class test_trace_task(Case):
 
         try:
             tid = uuid()
-            jail(tid, mytask.name, [2], {})
+            jail(self.app, tid, mytask.name, [2], {})
             self.assertIn(tid, Backend._started)
 
             mytask.ignore_result = True
             tid = uuid()
-            jail(tid, mytask.name, [2], {})
+            jail(self.app, tid, mytask.name, [2], {})
             self.assertNotIn(tid, Backend._started)
         finally:
             mytask.backend = prev
@@ -215,14 +213,14 @@ class test_trace_task(Case):
             mytask.ignore_result = False
 
     def test_execute_jail_failure(self):
-        ret = jail(uuid(), mytask_raising.name,
+        ret = jail(self.app, uuid(), mytask_raising.name,
                    [4], {})
         self.assertIsInstance(ret, ExceptionInfo)
         self.assertTupleEqual(ret.exception.args, (4, ))
 
     def test_execute_ignore_result(self):
         task_id = uuid()
-        ret = jail(task_id, MyTaskIgnoreResult.name, [4], {})
+        ret = jail(self.app, task_id, MyTaskIgnoreResult.name, [4], {})
         self.assertEqual(ret, 256)
         self.assertFalse(AsyncResult(task_id).ready())
 
@@ -355,7 +353,7 @@ class test_TaskRequest(AppCase):
             mytask.ignore_result = False
 
     def test_send_email(self):
-        app = app_or_default()
+        app = self.app
         old_mail_admins = app.mail_admins
         old_enable_mails = mytask.send_error_emails
         mail_sent = [False]
@@ -814,7 +812,7 @@ class test_TaskRequest(AppCase):
 
     @patch('celery.worker.job.logger')
     def _test_on_failure(self, exception, logger):
-        app = app_or_default()
+        app = self.app
         tid = uuid()
         tw = TaskRequest(mytask.name, tid, [4], {'f': 'x'})
         try:

+ 6 - 7
celery/tests/worker/test_worker.py

@@ -14,7 +14,6 @@ from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from mock import call, Mock, patch
 
-from celery import current_app
 from celery.app.defaults import DEFAULTS
 from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.concurrency.base import BasePool
@@ -233,13 +232,13 @@ class test_QoS(Case):
         qos.set(qos.prev)
 
 
-class test_Consumer(Case):
+class test_Consumer(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.buffer = FastQueue()
         self.timer = Timer()
 
-    def tearDown(self):
+    def teardown(self):
         self.timer.stop()
 
     def test_info(self):
@@ -432,7 +431,7 @@ class test_Consumer(Case):
 
     def test_loop_ignores_socket_timeout(self):
 
-        class Connection(current_app.connection().__class__):
+        class Connection(self.app.connection().__class__):
             obj = None
 
             def drain_events(self, **kwargs):
@@ -448,7 +447,7 @@ class test_Consumer(Case):
 
     def test_loop_when_socket_error(self):
 
-        class Connection(current_app.connection().__class__):
+        class Connection(self.app.connection().__class__):
             obj = None
 
             def drain_events(self, **kwargs):
@@ -470,7 +469,7 @@ class test_Consumer(Case):
 
     def test_loop(self):
 
-        class Connection(current_app.connection().__class__):
+        class Connection(self.app.connection().__class__):
             obj = None
 
             def drain_events(self, **kwargs):