浏览代码

[tests] Now depends on case

Ask Solem 9 年之前
父节点
当前提交
d31097ec90
共有 60 个文件被更改,包括 508 次插入1359 次删除
  1. 3 0
      celery/__init__.py
  2. 6 4
      celery/contrib/migrate.py
  3. 0 799
      celery/tests/_case.py
  4. 13 17
      celery/tests/app/test_app.py
  5. 2 2
      celery/tests/app/test_beat.py
  6. 9 9
      celery/tests/app/test_defaults.py
  7. 2 2
      celery/tests/app/test_loaders.py
  8. 86 88
      celery/tests/app/test_log.py
  9. 3 3
      celery/tests/app/test_schedules.py
  10. 2 4
      celery/tests/backends/test_amqp.py
  11. 3 4
      celery/tests/backends/test_base.py
  12. 15 17
      celery/tests/backends/test_cache.py
  13. 130 140
      celery/tests/backends/test_cassandra.py
  14. 2 4
      celery/tests/backends/test_couchbase.py
  15. 2 4
      celery/tests/backends/test_couchdb.py
  16. 6 12
      celery/tests/backends/test_database.py
  17. 2 2
      celery/tests/backends/test_elasticsearch.py
  18. 2 2
      celery/tests/backends/test_filesystem.py
  19. 6 7
      celery/tests/backends/test_mongodb.py
  20. 4 4
      celery/tests/backends/test_redis.py
  21. 2 5
      celery/tests/backends/test_riak.py
  22. 9 9
      celery/tests/bin/test_base.py
  23. 26 24
      celery/tests/bin/test_beat.py
  24. 2 2
      celery/tests/bin/test_celeryd_detach.py
  25. 2 2
      celery/tests/bin/test_events.py
  26. 2 2
      celery/tests/bin/test_multi.py
  27. 49 59
      celery/tests/bin/test_worker.py
  28. 9 5
      celery/tests/case.py
  29. 2 2
      celery/tests/concurrency/test_eventlet.py
  30. 2 2
      celery/tests/concurrency/test_gevent.py
  31. 2 2
      celery/tests/concurrency/test_pool.py
  32. 5 7
      celery/tests/concurrency/test_prefork.py
  33. 6 6
      celery/tests/concurrency/test_threads.py
  34. 5 5
      celery/tests/contrib/test_migrate.py
  35. 3 3
      celery/tests/contrib/test_rdb.py
  36. 2 2
      celery/tests/events/test_cursesmon.py
  37. 11 10
      celery/tests/events/test_snapshot.py
  38. 2 2
      celery/tests/events/test_state.py
  39. 11 13
      celery/tests/fixups/test_django.py
  40. 2 2
      celery/tests/security/case.py
  41. 3 3
      celery/tests/security/test_certificate.py
  42. 2 2
      celery/tests/security/test_security.py
  43. 0 0
      celery/tests/slow/__init__.py
  44. 2 2
      celery/tests/utils/test_datastructures.py
  45. 21 25
      celery/tests/utils/test_platforms.py
  46. 2 2
      celery/tests/utils/test_serialization.py
  47. 3 3
      celery/tests/utils/test_sysinfo.py
  48. 2 2
      celery/tests/utils/test_term.py
  49. 2 2
      celery/tests/utils/test_threads.py
  50. 1 1
      celery/tests/utils/test_timer2.py
  51. 3 3
      celery/tests/worker/test_autoreload.py
  52. 4 2
      celery/tests/worker/test_autoscale.py
  53. 2 2
      celery/tests/worker/test_components.py
  54. 2 4
      celery/tests/worker/test_consumer.py
  55. 2 2
      celery/tests/worker/test_request.py
  56. 2 2
      celery/tests/worker/test_worker.py
  57. 1 3
      requirements/test.txt
  58. 0 1
      requirements/test3.txt
  59. 1 5
      setup.py
  60. 1 5
      tox.ini

+ 3 - 0
celery/__init__.py

@@ -158,4 +158,7 @@ old_module, new_module = five.recreate_module(  # pragma: no cover
     version_info_t=version_info_t,
     version_info_t=version_info_t,
     maybe_patch_concurrency=maybe_patch_concurrency,
     maybe_patch_concurrency=maybe_patch_concurrency,
     _find_option_with_arg=_find_option_with_arg,
     _find_option_with_arg=_find_option_with_arg,
+    absolute_import=absolute_import,
+    unicode_literals=unicode_literals,
+    print_function=print_function,
 )
 )

+ 6 - 4
celery/contrib/migrate.py

@@ -21,10 +21,12 @@ from celery.app import app_or_default
 from celery.five import string, string_t
 from celery.five import string, string_t
 from celery.utils import worker_direct
 from celery.utils import worker_direct
 
 
-__all__ = ['StopFiltering', 'State', 'republish', 'migrate_task',
-           'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
-           'start_filter', 'move_task_by_id', 'move_by_idmap',
-           'move_by_taskmap', 'move_direct', 'move_direct_by_id']
+__all__ = [
+    'StopFiltering', 'State', 'republish', 'migrate_task',
+    'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
+    'start_filter', 'move_task_by_id', 'move_by_idmap',
+    'move_by_taskmap', 'move_direct', 'move_direct_by_id',
+]
 
 
 MOVING_PROGRESS_FMT = """\
 MOVING_PROGRESS_FMT = """\
 Moving task {state.filtered}/{state.strtotal}: \
 Moving task {state.filtered}/{state.strtotal}: \

+ 0 - 799
celery/tests/_case.py

@@ -1,799 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import importlib
-import inspect
-import io
-import logging
-import os
-import platform
-import re
-import sys
-import time
-import types
-import warnings
-
-from contextlib import contextmanager
-from functools import partial, wraps
-from six import (
-    iteritems as items,
-    itervalues as values,
-    string_types,
-    reraise,
-)
-from six.moves import builtins
-
-from nose import SkipTest
-
-try:
-    import unittest  # noqa
-    unittest.skip
-    from unittest.util import safe_repr, unorderable_list_difference
-except AttributeError:
-    import unittest2 as unittest  # noqa
-    from unittest2.util import safe_repr, unorderable_list_difference  # noqa
-
-try:
-    from unittest import mock
-except ImportError:
-    import mock  # noqa
-
-__all__ = [
-    'ANY', 'Case', 'ContextMock', 'MagicMock', 'Mock', 'MockCallbacks',
-    'call', 'patch', 'sentinel',
-
-    'mock_open', 'mock_context', 'mock_module',
-    'patch_modules', 'reset_modules', 'sys_platform', 'pypy_version',
-    'platform_pyimp', 'replace_module_value', 'override_stdouts',
-    'mask_modules', 'sleepdeprived', 'mock_environ', 'wrap_logger',
-    'restore_logging',
-
-    'todo', 'skip', 'skip_if_darwin', 'skip_if_environ',
-    'skip_if_jython', 'skip_if_platform', 'skip_if_pypy', 'skip_if_python3',
-    'skip_if_win32', 'skip_unless_module', 'skip_unless_symbol',
-]
-
-patch = mock.patch
-call = mock.call
-sentinel = mock.sentinel
-MagicMock = mock.MagicMock
-ANY = mock.ANY
-
-PY3 = sys.version_info[0] == 3
-if PY3:
-    open_fqdn = 'builtins.open'
-    module_name_t = str
-else:
-    open_fqdn = '__builtin__.open'  # noqa
-    module_name_t = bytes  # noqa
-
-StringIO = io.StringIO
-_SIO_write = StringIO.write
-_SIO_init = StringIO.__init__
-
-
-def symbol_by_name(name, aliases={}, imp=None, package=None,
-                   sep='.', default=None, **kwargs):
-    """Get symbol by qualified name.
-
-    The name should be the full dot-separated path to the class::
-
-        modulename.ClassName
-
-    Example::
-
-        celery.concurrency.processes.TaskPool
-                                    ^- class name
-
-    or using ':' to separate module and symbol::
-
-        celery.concurrency.processes:TaskPool
-
-    If `aliases` is provided, a dict containing short name/long name
-    mappings, the name is looked up in the aliases first.
-
-    Examples:
-
-        >>> symbol_by_name('celery.concurrency.processes.TaskPool')
-        <class 'celery.concurrency.processes.TaskPool'>
-
-        >>> symbol_by_name('default', {
-        ...     'default': 'celery.concurrency.processes.TaskPool'})
-        <class 'celery.concurrency.processes.TaskPool'>
-
-        # Does not try to look up non-string names.
-        >>> from celery.concurrency.processes import TaskPool
-        >>> symbol_by_name(TaskPool) is TaskPool
-        True
-
-    """
-    if imp is None:
-        imp = importlib.import_module
-
-    if not isinstance(name, string_types):
-        return name                                 # already a class
-
-    name = aliases.get(name) or name
-    sep = ':' if ':' in name else sep
-    module_name, _, cls_name = name.rpartition(sep)
-    if not module_name:
-        cls_name, module_name = None, package if package else cls_name
-    try:
-        try:
-            module = imp(module_name, package=package, **kwargs)
-        except ValueError as exc:
-            reraise(ValueError,
-                    ValueError("Couldn't import {0!r}: {1}".format(name, exc)),
-                    sys.exc_info()[2])
-        return getattr(module, cls_name) if cls_name else module
-    except (ImportError, AttributeError):
-        if default is None:
-            raise
-    return default
-
-
-class WhateverIO(StringIO):
-
-    def __init__(self, v=None, *a, **kw):
-        _SIO_init(self, v.decode() if isinstance(v, bytes) else v, *a, **kw)
-
-    def write(self, data):
-        _SIO_write(self, data.decode() if isinstance(data, bytes) else data)
-
-
-def noop(*args, **kwargs):
-    pass
-
-
-class Mock(mock.Mock):
-
-    def __init__(self, *args, **kwargs):
-        attrs = kwargs.pop('attrs', None) or {}
-        super(Mock, self).__init__(*args, **kwargs)
-        for attr_name, attr_value in items(attrs):
-            setattr(self, attr_name, attr_value)
-
-
-class _ContextMock(Mock):
-    """Dummy class implementing __enter__ and __exit__
-    as the :keyword:`with` statement requires these to be implemented
-    in the class, not just the instance."""
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *exc_info):
-        pass
-
-
-def ContextMock(*args, **kwargs):
-    obj = _ContextMock(*args, **kwargs)
-    obj.attach_mock(_ContextMock(), '__enter__')
-    obj.attach_mock(_ContextMock(), '__exit__')
-    obj.__enter__.return_value = obj
-    # if __exit__ return a value the exception is ignored,
-    # so it must return None here.
-    obj.__exit__.return_value = None
-    return obj
-
-
-def _bind(f, o):
-    @wraps(f)
-    def bound_meth(*fargs, **fkwargs):
-        return f(o, *fargs, **fkwargs)
-    return bound_meth
-
-
-if PY3:  # pragma: no cover
-    def _get_class_fun(meth):
-        return meth
-else:
-    def _get_class_fun(meth):
-        return meth.__func__
-
-
-class MockCallbacks(object):
-
-    def __new__(cls, *args, **kwargs):
-        r = Mock(name=cls.__name__)
-        _get_class_fun(cls.__init__)(r, *args, **kwargs)
-        for key, value in items(vars(cls)):
-            if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
-                if inspect.ismethod(value) or inspect.isfunction(value):
-                    r.__getattr__(key).side_effect = _bind(value, r)
-                else:
-                    r.__setattr__(key, value)
-        return r
-
-
-# -- adds assertWarns from recent unittest2, not in Python 2.7.
-
-class _AssertRaisesBaseContext(object):
-
-    def __init__(self, expected, test_case, callable_obj=None,
-                 expected_regex=None):
-        self.expected = expected
-        self.failureException = test_case.failureException
-        self.obj_name = None
-        if isinstance(expected_regex, string_types):
-            expected_regex = re.compile(expected_regex)
-        self.expected_regex = expected_regex
-
-
-def _is_magic_module(m):
-    # some libraries create custom module types that are lazily
-    # lodaded, e.g. Django installs some modules in sys.modules that
-    # will load _tkinter and other shit when touched.
-
-    # pyflakes refuses to accept 'noqa' for this isinstance.
-    cls, modtype = type(m), types.ModuleType
-    try:
-        variables = vars(cls)
-    except TypeError:
-        return True
-    else:
-        return (cls is not modtype and (
-            '__getattr__' in variables or
-            '__getattribute__' in variables))
-
-
-class _AssertWarnsContext(_AssertRaisesBaseContext):
-    """A context manager used to implement TestCase.assertWarns* methods."""
-
-    def __enter__(self):
-        # The __warningregistry__'s need to be in a pristine state for tests
-        # to work properly.
-        warnings.resetwarnings()
-        for v in list(values(sys.modules)):
-            # do not evaluate Django moved modules and other lazily
-            # initialized modules.
-            if v and not _is_magic_module(v):
-                # use raw __getattribute__ to protect even better from
-                # lazily loaded modules
-                try:
-                    object.__getattribute__(v, '__warningregistry__')
-                except AttributeError:
-                    pass
-                else:
-                    object.__setattr__(v, '__warningregistry__', {})
-        self.warnings_manager = warnings.catch_warnings(record=True)
-        self.warnings = self.warnings_manager.__enter__()
-        warnings.simplefilter('always', self.expected)
-        return self
-
-    def __exit__(self, exc_type, exc_value, tb):
-        self.warnings_manager.__exit__(exc_type, exc_value, tb)
-        if exc_type is not None:
-            # let unexpected exceptions pass through
-            return
-        try:
-            exc_name = self.expected.__name__
-        except AttributeError:
-            exc_name = str(self.expected)
-        first_matching = None
-        for m in self.warnings:
-            w = m.message
-            if not isinstance(w, self.expected):
-                continue
-            if first_matching is None:
-                first_matching = w
-            if (self.expected_regex is not None and
-                    not self.expected_regex.search(str(w))):
-                continue
-            # store warning for later retrieval
-            self.warning = w
-            self.filename = m.filename
-            self.lineno = m.lineno
-            return
-        # Now we simply try to choose a helpful failure message
-        if first_matching is not None:
-            raise self.failureException(
-                '%r does not match %r' % (
-                    self.expected_regex.pattern, str(first_matching)))
-        if self.obj_name:
-            raise self.failureException(
-                '%s not triggered by %s' % (exc_name, self.obj_name))
-        else:
-            raise self.failureException('%s not triggered' % exc_name)
-
-
-class Case(unittest.TestCase):
-    DeprecationWarning = DeprecationWarning
-    PendingDeprecationWarning = PendingDeprecationWarning
-
-    def patch(self, *path, **options):
-        manager = patch('.'.join(path), **options)
-        patched = manager.start()
-        self.addCleanup(manager.stop)
-        return patched
-
-    def mock_modules(self, *mods):
-        modules = []
-        for mod in mods:
-            mod = mod.split('.')
-            modules.extend(reversed([
-                '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
-            ]))
-        modules = sorted(set(modules))
-        return self.wrap_context(mock_module(*modules))
-
-    def on_nth_call_do(self, mock, side_effect, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.side_effect = side_effect
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def on_nth_call_return(self, mock, retval, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.return_value = retval
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def mask_modules(self, *modules):
-        self.wrap_context(mask_modules(*modules))
-
-    def wrap_context(self, context):
-        ret = context.__enter__()
-        self.addCleanup(partial(context.__exit__, None, None, None))
-        return ret
-
-    def mock_environ(self, env_name, env_value):
-        return self.wrap_context(mock_environ(env_name, env_value))
-
-    def assertWarns(self, expected_warning):
-        return _AssertWarnsContext(expected_warning, self, None)
-
-    def assertWarnsRegex(self, expected_warning, expected_regex):
-        return _AssertWarnsContext(expected_warning, self,
-                                   None, expected_regex)
-
-    @contextmanager
-    def assertDeprecated(self):
-        with self.assertWarnsRegex(self.DeprecationWarning,
-                                   r'scheduled for removal'):
-            yield
-
-    @contextmanager
-    def assertPendingDeprecation(self):
-        with self.assertWarnsRegex(self.PendingDeprecationWarning,
-                                   r'scheduled for deprecation'):
-            yield
-
-    def assertDictContainsSubset(self, expected, actual, msg=None):
-        missing, mismatched = [], []
-
-        for key, value in items(expected):
-            if key not in actual:
-                missing.append(key)
-            elif value != actual[key]:
-                mismatched.append('%s, expected: %s, actual: %s' % (
-                    safe_repr(key), safe_repr(value),
-                    safe_repr(actual[key])))
-
-        if not (missing or mismatched):
-            return
-
-        standard_msg = ''
-        if missing:
-            standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
-
-        if mismatched:
-            if standard_msg:
-                standard_msg += '; '
-            standard_msg += 'Mismatched values: %s' % (
-                ','.join(mismatched))
-
-        self.fail(self._formatMessage(msg, standard_msg))
-
-    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
-        missing = unexpected = None
-        try:
-            expected = sorted(expected_seq)
-            actual = sorted(actual_seq)
-        except TypeError:
-            # Unsortable items (example: set(), complex(), ...)
-            expected = list(expected_seq)
-            actual = list(actual_seq)
-            missing, unexpected = unorderable_list_difference(
-                expected, actual)
-        else:
-            return self.assertSequenceEqual(expected, actual, msg=msg)
-
-        errors = []
-        if missing:
-            errors.append(
-                'Expected, but missing:\n    %s' % (safe_repr(missing),)
-            )
-        if unexpected:
-            errors.append(
-                'Unexpected, but present:\n    %s' % (safe_repr(unexpected),)
-            )
-        if errors:
-            standardMsg = '\n'.join(errors)
-            self.fail(self._formatMessage(msg, standardMsg))
-
-
-class _CallableContext(object):
-
-    def __init__(self, context, cargs, ckwargs, fun):
-        self.context = context
-        self.cargs = cargs
-        self.ckwargs = ckwargs
-        self.fun = fun
-
-    def __call__(self, *args, **kwargs):
-        return self.fun(*args, **kwargs)
-
-    def __enter__(self):
-        self.ctx = self.context(*self.cargs, **self.ckwargs)
-        return self.ctx.__enter__()
-
-    def __exit__(self, *einfo):
-        if self.ctx:
-            return self.ctx.__exit__(*einfo)
-
-
-def decorator(predicate):
-
-    @wraps(predicate)
-    def take_arguments(*pargs, **pkwargs):
-
-        @wraps(predicate)
-        def decorator(cls):
-            if inspect.isclass(cls):
-                orig_setup = cls.setUp
-                orig_teardown = cls.tearDown
-
-                @wraps(cls.setUp)
-                def around_setup(*args, **kwargs):
-                    try:
-                        contexts = args[0].__rb3dc_contexts__
-                    except AttributeError:
-                        contexts = args[0].__rb3dc_contexts__ = []
-                    p = predicate(*pargs, **pkwargs)
-                    p.__enter__()
-                    contexts.append(p)
-                    return orig_setup(*args, **kwargs)
-                around_setup.__wrapped__ = cls.setUp
-                cls.setUp = around_setup
-
-                @wraps(cls.tearDown)
-                def around_teardown(*args, **kwargs):
-                    try:
-                        contexts = args[0].__rb3dc_contexts__
-                    except AttributeError:
-                        pass
-                    else:
-                        for context in contexts:
-                            context.__exit__(*sys.exc_info())
-                    orig_teardown(*args, **kwargs)
-                around_teardown.__wrapped__ = cls.tearDown
-                cls.tearDown = around_teardown
-
-                return cls
-            else:
-                @wraps(cls)
-                def around_case(*args, **kwargs):
-                    with predicate(*pargs, **pkwargs):
-                        return cls(*args, **kwargs)
-                return around_case
-
-        if len(pargs) == 1 and callable(pargs[0]):
-            fun, pargs = pargs[0], ()
-            return decorator(fun)
-        return _CallableContext(predicate, pargs, pkwargs, decorator)
-    return take_arguments
-
-
-@decorator
-@contextmanager
-def skip_unless_module(module, name=None):
-    try:
-        importlib.import_module(module)
-    except (ImportError, OSError):
-        raise SkipTest('module not installed: {0}'.format(name or module))
-    yield
-
-
-@decorator
-@contextmanager
-def skip_unless_symbol(symbol, name=None):
-    try:
-        symbol_by_name(symbol)
-    except (AttributeError, ImportError):
-        raise SkipTest('missing symbol {0}'.format(name or symbol))
-    yield
-
-
-def get_logger_handlers(logger):
-    return [
-        h for h in logger.handlers
-        if not isinstance(h, logging.NullHandler)
-    ]
-
-
-@decorator
-@contextmanager
-def wrap_logger(logger, loglevel=logging.ERROR):
-    old_handlers = get_logger_handlers(logger)
-    sio = WhateverIO()
-    siohandler = logging.StreamHandler(sio)
-    logger.handlers = [siohandler]
-
-    try:
-        yield sio
-    finally:
-        logger.handlers = old_handlers
-
-
-@decorator
-@contextmanager
-def mock_environ(env_name, env_value):
-    sentinel = object()
-    prev_val = os.environ.get(env_name, sentinel)
-    os.environ[env_name] = env_value
-    try:
-        yield env_value
-    finally:
-        if prev_val is sentinel:
-            os.environ.pop(env_name, None)
-        else:
-            os.environ[env_name] = prev_val
-
-
-@decorator
-@contextmanager
-def sleepdeprived(module=time):
-    old_sleep, module.sleep = module.sleep, noop
-    try:
-        yield
-    finally:
-        module.sleep = old_sleep
-
-
-@decorator
-@contextmanager
-def skip_if_python3(reason='incompatible'):
-    if PY3:
-        raise SkipTest('Python3: {0}'.format(reason))
-    yield
-
-
-@decorator
-@contextmanager
-def skip_if_environ(env_var_name):
-    if os.environ.get(env_var_name):
-        raise SkipTest('envvar {0} set'.format(env_var_name))
-    yield
-
-
-@decorator
-@contextmanager
-def _skip_test(reason, sign):
-    raise SkipTest('{0}: {1}'.format(sign, reason))
-    yield
-todo = partial(_skip_test, sign='TODO')
-skip = partial(_skip_test, sign='SKIP')
-
-
-# Taken from
-# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
-@decorator
-@contextmanager
-def mask_modules(*modnames):
-    """Ban some modules from being importable inside the context
-
-    For example:
-
-        >>> with mask_modules('sys'):
-        ...     try:
-        ...         import sys
-        ...     except ImportError:
-        ...         print('sys not found')
-        sys not found
-
-        >>> import sys  # noqa
-        >>> sys.version
-        (2, 5, 2, 'final', 0)
-
-    """
-    realimport = builtins.__import__
-
-    def myimp(name, *args, **kwargs):
-        if name in modnames:
-            raise ImportError('No module named %s' % name)
-        else:
-            return realimport(name, *args, **kwargs)
-
-    builtins.__import__ = myimp
-    try:
-        yield True
-    finally:
-        builtins.__import__ = realimport
-
-
-@decorator
-@contextmanager
-def override_stdouts():
-    """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
-    prev_out, prev_err = sys.stdout, sys.stderr
-    prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
-    mystdout, mystderr = WhateverIO(), WhateverIO()
-    sys.stdout = sys.__stdout__ = mystdout
-    sys.stderr = sys.__stderr__ = mystderr
-
-    try:
-        yield mystdout, mystderr
-    finally:
-        sys.stdout = prev_out
-        sys.stderr = prev_err
-        sys.__stdout__ = prev_rout
-        sys.__stderr__ = prev_rerr
-
-
-@decorator
-@contextmanager
-def replace_module_value(module, name, value=None):
-    has_prev = hasattr(module, name)
-    prev = getattr(module, name, None)
-    if value:
-        setattr(module, name, value)
-    else:
-        try:
-            delattr(module, name)
-        except AttributeError:
-            pass
-    try:
-        yield
-    finally:
-        if prev is not None:
-            setattr(module, name, prev)
-        if not has_prev:
-            try:
-                delattr(module, name)
-            except AttributeError:
-                pass
-pypy_version = partial(
-    replace_module_value, sys, 'pypy_version_info',
-)
-platform_pyimp = partial(
-    replace_module_value, platform, 'python_implementation',
-)
-
-
-@decorator
-@contextmanager
-def sys_platform(value):
-    prev, sys.platform = sys.platform, value
-    try:
-        yield
-    finally:
-        sys.platform = prev
-
-
-@decorator
-@contextmanager
-def reset_modules(*modules):
-    prev = {k: sys.modules.pop(k) for k in modules if k in sys.modules}
-    try:
-        yield
-    finally:
-        sys.modules.update(prev)
-
-
-@decorator
-@contextmanager
-def patch_modules(*modules):
-    prev = {}
-    for mod in modules:
-        prev[mod] = sys.modules.get(mod)
-        sys.modules[mod] = types.ModuleType(module_name_t(mod))
-    try:
-        yield
-    finally:
-        for name, mod in items(prev):
-            if mod is None:
-                sys.modules.pop(name, None)
-            else:
-                sys.modules[name] = mod
-
-
-@decorator
-@contextmanager
-def mock_module(*names):
-    prev = {}
-
-    class MockModule(types.ModuleType):
-
-        def __getattr__(self, attr):
-            setattr(self, attr, Mock())
-            return types.ModuleType.__getattribute__(self, attr)
-
-    mods = []
-    for name in names:
-        try:
-            prev[name] = sys.modules[name]
-        except KeyError:
-            pass
-        mod = sys.modules[name] = MockModule(module_name_t(name))
-        mods.append(mod)
-    try:
-        yield mods
-    finally:
-        for name in names:
-            try:
-                sys.modules[name] = prev[name]
-            except KeyError:
-                try:
-                    del(sys.modules[name])
-                except KeyError:
-                    pass
-
-
-@contextmanager
-def mock_context(mock, typ=Mock):
-    context = mock.return_value = Mock()
-    context.__enter__ = typ()
-    context.__exit__ = typ()
-
-    def on_exit(*x):
-        if x[0]:
-            reraise(x[0], x[1], x[2])
-    context.__exit__.side_effect = on_exit
-    context.__enter__.return_value = context
-    try:
-        yield context
-    finally:
-        context.reset()
-
-
-@decorator
-@contextmanager
-def mock_open(typ=WhateverIO, side_effect=None):
-    with patch(open_fqdn) as open_:
-        with mock_context(open_) as context:
-            if side_effect is not None:
-                context.__enter__.side_effect = side_effect
-            val = context.__enter__.return_value = typ()
-            val.__exit__ = Mock()
-            yield val
-
-
-@decorator
-@contextmanager
-def skip_if_platform(platform_name, name=None):
-    if sys.platform.startswith(platform_name):
-        raise SkipTest('does not work on {0}'.format(platform_name or name))
-    yield
-skip_if_jython = partial(skip_if_platform, 'java', name='Jython')
-skip_if_win32 = partial(skip_if_platform, 'win32', name='Windows')
-skip_if_darwin = partial(skip_if_platform, 'darwin', name='OS X')
-
-
-@decorator
-@contextmanager
-def skip_if_pypy():
-    if getattr(sys, 'pypy_version_info', None):
-        raise SkipTest('does not work on PyPy')
-    yield
-
-
-@decorator
-@contextmanager
-def restore_logging():
-    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
-    root = logging.getLogger()
-    level = root.level
-    handlers = root.handlers
-
-    try:
-        yield
-    finally:
-        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
-        root.level = level
-        root.handlers[:] = handlers

+ 13 - 17
celery/tests/app/test_app.py

@@ -29,12 +29,8 @@ from celery.tests.case import (
     Case,
     Case,
     ContextMock,
     ContextMock,
     depends_on_current_app,
     depends_on_current_app,
-    mask_modules,
+    mock,
     patch,
     patch,
-    platform_pyimp,
-    sys_platform,
-    pypy_version,
-    mock_environ,
 )
 )
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.mail import ErrorMail
 from celery.utils.mail import ErrorMail
@@ -236,7 +232,7 @@ class test_App(AppCase):
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
         )
         )
 
 
-    @mock_environ('CELERY_BROKER_URL', '')
+    @mock.environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
     def test_with_broker(self):
         with self.Celery(broker='foo://baribaz') as app:
         with self.Celery(broker='foo://baribaz') as app:
             self.assertEqual(app.conf.broker_url, 'foo://baribaz')
             self.assertEqual(app.conf.broker_url, 'foo://baribaz')
@@ -850,7 +846,7 @@ class test_App(AppCase):
         self.assertIn('add2', self.app.conf.beat_schedule)
         self.assertIn('add2', self.app.conf.beat_schedule)
 
 
     def test_pool_no_multiprocessing(self):
     def test_pool_no_multiprocessing(self):
-        with mask_modules('multiprocessing.util'):
+        with mock.mask_modules('multiprocessing.util'):
             pool = self.app.pool
             pool = self.app.pool
             self.assertIs(pool, self.app._pool)
             self.assertIs(pool, self.app._pool)
 
 
@@ -953,26 +949,26 @@ class test_debugging_utils(AppCase):
 class test_pyimplementation(AppCase):
 class test_pyimplementation(AppCase):
 
 
     def test_platform_python_implementation(self):
     def test_platform_python_implementation(self):
-        with platform_pyimp(lambda: 'Xython'):
+        with mock.platform_pyimp(lambda: 'Xython'):
             self.assertEqual(pyimplementation(), 'Xython')
             self.assertEqual(pyimplementation(), 'Xython')
 
 
     def test_platform_jython(self):
     def test_platform_jython(self):
-        with platform_pyimp():
-            with sys_platform('java 1.6.51'):
+        with mock.platform_pyimp():
+            with mock.sys_platform('java 1.6.51'):
                 self.assertIn('Jython', pyimplementation())
                 self.assertIn('Jython', pyimplementation())
 
 
     def test_platform_pypy(self):
     def test_platform_pypy(self):
-        with platform_pyimp():
-            with sys_platform('darwin'):
-                with pypy_version((1, 4, 3)):
+        with mock.platform_pyimp():
+            with mock.sys_platform('darwin'):
+                with mock.pypy_version((1, 4, 3)):
                     self.assertIn('PyPy', pyimplementation())
                     self.assertIn('PyPy', pyimplementation())
-                with pypy_version((1, 4, 3, 'a4')):
+                with mock.pypy_version((1, 4, 3, 'a4')):
                     self.assertIn('PyPy', pyimplementation())
                     self.assertIn('PyPy', pyimplementation())
 
 
     def test_platform_fallback(self):
     def test_platform_fallback(self):
-        with platform_pyimp():
-            with sys_platform('darwin'):
-                with pypy_version():
+        with mock.platform_pyimp():
+            with mock.sys_platform('darwin'):
+                with mock.pypy_version():
                     self.assertEqual('CPython', pyimplementation())
                     self.assertEqual('CPython', pyimplementation())
 
 
 
 

+ 2 - 2
celery/tests/app/test_beat.py

@@ -11,7 +11,7 @@ from celery.schedules import schedule
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import AppCase, Mock, call, patch, skip_unless_module
+from celery.tests.case import AppCase, Mock, call, patch, skip
 
 
 
 
 class MockShelve(dict):
 class MockShelve(dict):
@@ -485,7 +485,7 @@ class test_Service(AppCase):
 
 
 class test_EmbeddedService(AppCase):
 class test_EmbeddedService(AppCase):
 
 
-    @skip_unless_module('_multiprocessing', name='multiprocessing')
+    @skip.unless_module('_multiprocessing', name='multiprocessing')
     def test_start_stop_process(self):
     def test_start_stop_process(self):
         from billiard.process import Process
         from billiard.process import Process
 
 

+ 9 - 9
celery/tests/app/test_defaults.py

@@ -10,7 +10,7 @@ from celery.app.defaults import (
 )
 )
 from celery.five import values
 from celery.five import values
 
 
-from celery.tests.case import AppCase, pypy_version, sys_platform
+from celery.tests.case import AppCase, mock
 
 
 
 
 class test_defaults(AppCase):
 class test_defaults(AppCase):
@@ -29,15 +29,15 @@ class test_defaults(AppCase):
         val = object()
         val = object()
         self.assertIs(self.defaults.Option.typemap['any'](val), val)
         self.assertIs(self.defaults.Option.typemap['any'](val), val)
 
 
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 4, 0))
     def test_default_pool_pypy_14(self):
     def test_default_pool_pypy_14(self):
-        with sys_platform('darwin'):
-            with pypy_version((1, 4, 0)):
-                self.assertEqual(self.defaults.DEFAULT_POOL, 'solo')
+        self.assertEqual(self.defaults.DEFAULT_POOL, 'solo')
 
 
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 5, 0))
     def test_default_pool_pypy_15(self):
     def test_default_pool_pypy_15(self):
-        with sys_platform('darwin'):
-            with pypy_version((1, 5, 0)):
-                self.assertEqual(self.defaults.DEFAULT_POOL, 'prefork')
+        self.assertEqual(self.defaults.DEFAULT_POOL, 'prefork')
 
 
     def test_compat_indices(self):
     def test_compat_indices(self):
         self.assertFalse(any(key.isupper() for key in DEFAULTS))
         self.assertFalse(any(key.isupper() for key in DEFAULTS))
@@ -54,9 +54,9 @@ class test_defaults(AppCase):
         for key in _TO_OLD_KEY:
         for key in _TO_OLD_KEY:
             self.assertIn(key, SETTING_KEYS)
             self.assertIn(key, SETTING_KEYS)
 
 
+    @mock.sys_platform('java 1.6.51')
     def test_default_pool_jython(self):
     def test_default_pool_jython(self):
-        with sys_platform('java 1.6.51'):
-            self.assertEqual(self.defaults.DEFAULT_POOL, 'threads')
+        self.assertEqual(self.defaults.DEFAULT_POOL, 'threads')
 
 
     def test_find(self):
     def test_find(self):
         find = self.defaults.find
         find = self.defaults.find

+ 2 - 2
celery/tests/app/test_loaders.py

@@ -13,7 +13,7 @@ from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 from celery.utils.imports import NotAPackage
 from celery.utils.mail import SendmailWarning
 from celery.utils.mail import SendmailWarning
 
 
-from celery.tests.case import AppCase, Case, Mock, mock_environ, patch
+from celery.tests.case import AppCase, Case, Mock, mock, patch
 
 
 
 
 class DummyLoader(base.BaseLoader):
 class DummyLoader(base.BaseLoader):
@@ -144,7 +144,7 @@ class test_DefaultLoader(AppCase):
             l.read_configuration(fail_silently=False)
             l.read_configuration(fail_silently=False)
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
-    @mock_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
+    @mock.environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
     def test_read_configuration_py_in_name(self, find_module):
     def test_read_configuration_py_in_name(self, find_module):
         find_module.side_effect = NotAPackage()
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)

+ 86 - 88
celery/tests/app/test_log.py

@@ -20,11 +20,9 @@ from celery.utils.log import (
     in_sighandler,
     in_sighandler,
     logger_isa,
     logger_isa,
 )
 )
-from celery.tests.case import (
-    AppCase, Mock, mask_modules, skip_if_python3,
-    override_stdouts, patch, wrap_logger, restore_logging,
-)
-from celery.tests._case import get_logger_handlers
+
+from case.utils import get_logger_handlers
+from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 
 
 
 class test_TaskFormatter(AppCase):
 class test_TaskFormatter(AppCase):
@@ -156,7 +154,7 @@ class test_ColorFormatter(AppCase):
         self.assertIn('<Unrepresentable', msg)
         self.assertIn('<Unrepresentable', msg)
         self.assertEqual(safe_str.call_count, 1)
         self.assertEqual(safe_str.call_count, 1)
 
 
-    @skip_if_python3()
+    @skip.if_python3()
     @patch('celery.utils.log.safe_str')
     @patch('celery.utils.log.safe_str')
     def test_format_raises_no_color(self, safe_str):
     def test_format_raises_no_color(self, safe_str):
         x = ColorFormatter(use_color=False)
         x = ColorFormatter(use_color=False)
@@ -184,14 +182,14 @@ class test_default_logger(AppCase):
         logger = get_logger(base_logger.name)
         logger = get_logger(base_logger.name)
         self.assertIs(logger.parent, logging.root)
         self.assertIs(logger.parent, logging.root)
 
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_misc(self):
     def test_setup_logging_subsystem_misc(self):
-        with restore_logging():
-            self.app.log.setup_logging_subsystem(loglevel=None)
+        self.app.log.setup_logging_subsystem(loglevel=None)
 
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_misc2(self):
     def test_setup_logging_subsystem_misc2(self):
-        with restore_logging():
-            self.app.conf.worker_hijack_root_logger = True
-            self.app.log.setup_logging_subsystem()
+        self.app.conf.worker_hijack_root_logger = True
+        self.app.log.setup_logging_subsystem()
 
 
     def test_get_default_logger(self):
     def test_get_default_logger(self):
         self.assertTrue(self.app.log.get_default_logger())
         self.assertTrue(self.app.log.get_default_logger())
@@ -202,19 +200,19 @@ class test_default_logger(AppCase):
         self.app.log._configure_logger(None, sys.stderr, None, '', False)
         self.app.log._configure_logger(None, sys.stderr, None, '', False)
         logger.handlers[:] = []
         logger.handlers[:] = []
 
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_colorize(self):
     def test_setup_logging_subsystem_colorize(self):
-        with restore_logging():
-            self.app.log.setup_logging_subsystem(colorize=None)
-            self.app.log.setup_logging_subsystem(colorize=True)
+        self.app.log.setup_logging_subsystem(colorize=None)
+        self.app.log.setup_logging_subsystem(colorize=True)
 
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_no_mputil(self):
     def test_setup_logging_subsystem_no_mputil(self):
-        with restore_logging():
-            with mask_modules('billiard.util'):
-                self.app.log.setup_logging_subsystem()
+        with mock.mask_modules('billiard.util'):
+            self.app.log.setup_logging_subsystem()
 
 
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
 
 
-        with wrap_logger(logger, loglevel=loglevel) as sio:
+        with mock.wrap_logger(logger, loglevel=loglevel) as sio:
             logger.log(loglevel, logmsg)
             logger.log(loglevel, logmsg)
             return sio.getvalue().strip()
             return sio.getvalue().strip()
 
 
@@ -226,30 +224,30 @@ class test_default_logger(AppCase):
         val = self._assertLog(logger, logmsg, loglevel=loglevel)
         val = self._assertLog(logger, logmsg, loglevel=loglevel)
         return self.assertFalse(val, reason)
         return self.assertFalse(val, reason)
 
 
+    @mock.restore_logging()
     def test_setup_logger(self):
     def test_setup_logger(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False, colorize=True)
-            logger.handlers = []
-            self.app.log.already_setup = False
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False, colorize=None)
-            self.assertIs(
-                get_logger_handlers(logger)[0].stream, sys.__stderr__,
-                'setup_logger logs to stderr without logfile argument.',
-            )
-
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False, colorize=True)
+        logger.handlers = []
+        self.app.log.already_setup = False
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False, colorize=None)
+        self.assertIs(
+            get_logger_handlers(logger)[0].stream, sys.__stderr__,
+            'setup_logger logs to stderr without logfile argument.',
+        )
+
+    @mock.restore_logging()
     def test_setup_logger_no_handlers_stream(self):
     def test_setup_logger_no_handlers_stream(self):
-        with restore_logging():
-            l = self.get_logger()
-            l.handlers = []
+        l = self.get_logger()
+        l.handlers = []
 
 
-            with override_stdouts() as outs:
-                stdout, stderr = outs
-                l = self.setup_logger(logfile=sys.stderr,
-                                      loglevel=logging.INFO, root=False)
-                l.info('The quick brown fox...')
-                self.assertIn('The quick brown fox...', stderr.getvalue())
+        with mock.stdouts() as outs:
+            stdout, stderr = outs
+            l = self.setup_logger(logfile=sys.stderr,
+                                  loglevel=logging.INFO, root=False)
+            l.info('The quick brown fox...')
+            self.assertIn('The quick brown fox...', stderr.getvalue())
 
 
     @patch('os.fstat')
     @patch('os.fstat')
     def test_setup_logger_no_handlers_file(self, *args):
     def test_setup_logger_no_handlers_file(self, *args):
@@ -257,7 +255,7 @@ class test_default_logger(AppCase):
         _open = ('builtins.open' if sys.version_info[0] == 3
         _open = ('builtins.open' if sys.version_info[0] == 3
                  else '__builtin__.open')
                  else '__builtin__.open')
         with patch(_open) as osopen:
         with patch(_open) as osopen:
-            with restore_logging():
+            with mock.restore_logging():
                 files = defaultdict(StringIO)
                 files = defaultdict(StringIO)
 
 
                 def open_file(filename, *args, **kwargs):
                 def open_file(filename, *args, **kwargs):
@@ -277,59 +275,59 @@ class test_default_logger(AppCase):
                 )
                 )
                 self.assertIn(tempfile, files)
                 self.assertIn(tempfile, files)
 
 
+    @mock.restore_logging()
     def test_redirect_stdouts(self):
     def test_redirect_stdouts(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False)
-            try:
-                with wrap_logger(logger) as sio:
-                    self.app.log.redirect_stdouts_to_logger(
-                        logger, loglevel=logging.ERROR,
-                    )
-                    logger.error('foo')
-                    self.assertIn('foo', sio.getvalue())
-                    self.app.log.redirect_stdouts_to_logger(
-                        logger, stdout=False, stderr=False,
-                    )
-            finally:
-                sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False)
+        try:
+            with mock.wrap_logger(logger) as sio:
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, loglevel=logging.ERROR,
+                )
+                logger.error('foo')
+                self.assertIn('foo', sio.getvalue())
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, stdout=False, stderr=False,
+                )
+        finally:
+            sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
 
+    @mock.restore_logging()
     def test_logging_proxy(self):
     def test_logging_proxy(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False)
-
-            with wrap_logger(logger) as sio:
-                p = LoggingProxy(logger, loglevel=logging.ERROR)
-                p.close()
-                p.write('foo')
-                self.assertNotIn('foo', sio.getvalue())
-                p.closed = False
-                p.write('foo')
-                self.assertIn('foo', sio.getvalue())
-                lines = ['baz', 'xuzzy']
-                p.writelines(lines)
-                for line in lines:
-                    self.assertIn(line, sio.getvalue())
-                p.flush()
-                p.close()
-                self.assertFalse(p.isatty())
-
-                with override_stdouts() as (stdout, stderr):
-                    with in_sighandler():
-                        p.write('foo')
-                        self.assertTrue(stderr.getvalue())
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False)
 
 
-    def test_logging_proxy_recurse_protection(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False)
+        with mock.wrap_logger(logger) as sio:
             p = LoggingProxy(logger, loglevel=logging.ERROR)
             p = LoggingProxy(logger, loglevel=logging.ERROR)
-            p._thread.recurse_protection = True
-            try:
-                self.assertIsNone(p.write('FOOFO'))
-            finally:
-                p._thread.recurse_protection = False
+            p.close()
+            p.write('foo')
+            self.assertNotIn('foo', sio.getvalue())
+            p.closed = False
+            p.write('foo')
+            self.assertIn('foo', sio.getvalue())
+            lines = ['baz', 'xuzzy']
+            p.writelines(lines)
+            for line in lines:
+                self.assertIn(line, sio.getvalue())
+            p.flush()
+            p.close()
+            self.assertFalse(p.isatty())
+
+            with mock.stdouts() as (stdout, stderr):
+                with in_sighandler():
+                    p.write('foo')
+                    self.assertTrue(stderr.getvalue())
+
+    @mock.restore_logging()
+    def test_logging_proxy_recurse_protection(self):
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False)
+        p = LoggingProxy(logger, loglevel=logging.ERROR)
+        p._thread.recurse_protection = True
+        try:
+            self.assertIsNone(p.write('FOOFO'))
+        finally:
+            p._thread.recurse_protection = False
 
 
 
 
 class test_task_logger(test_default_logger):
 class test_task_logger(test_default_logger):

+ 3 - 3
celery/tests/app/test_schedules.py

@@ -10,7 +10,7 @@ from celery.five import items
 from celery.schedules import (
 from celery.schedules import (
     ParseException, crontab, crontab_parser, schedule, solar,
     ParseException, crontab, crontab_parser, schedule, solar,
 )
 )
-from celery.tests.case import AppCase, Mock, skip_unless_module, todo
+from celery.tests.case import AppCase, Mock, skip
 
 
 
 
 @contextmanager
 @contextmanager
@@ -23,7 +23,7 @@ def patch_crontab_nowfun(cls, retval):
         cls.nowfun = prev_nowfun
         cls.nowfun = prev_nowfun
 
 
 
 
-@skip_unless_module('ephem')
+@skip.unless_module('ephem')
 class test_solar(AppCase):
 class test_solar(AppCase):
 
 
     def setup(self):
     def setup(self):
@@ -735,7 +735,7 @@ class test_crontab_is_due(AppCase):
             self.assertTrue(due)
             self.assertTrue(due)
             self.assertEqual(remaining, 60.)
             self.assertEqual(remaining, 60.)
 
 
-    @todo('unstable test')
+    @skip.todo('unstable test')
     def test_monthly_moy_execution_is_not_due(self):
     def test_monthly_moy_execution_is_not_due(self):
         with patch_crontab_nowfun(
         with patch_crontab_nowfun(
                 self.monthly_moy, datetime(2013, 6, 28, 14, 30)):
                 self.monthly_moy, datetime(2013, 6, 28, 14, 30)):

+ 2 - 4
celery/tests/backends/test_amqp.py

@@ -14,9 +14,7 @@ from celery.five import Empty, Queue, range
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.utils import uuid
 from celery.utils import uuid
 
 
-from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, sleepdeprived,
-)
+from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
 
 
 
 
 class SomeClass(object):
 class SomeClass(object):
@@ -112,7 +110,7 @@ class test_AMQPBackend(AppCase):
         b = self.create_backend(expires=timedelta(minutes=1))
         b = self.create_backend(expires=timedelta(minutes=1))
         self.assertEqual(b.queue_arguments.get('x-expires'), 60 * 1000.0)
         self.assertEqual(b.queue_arguments.get('x-expires'), 60 * 1000.0)
 
 
-    @sleepdeprived()
+    @mock.sleepdeprived()
     def test_store_result_retries(self):
     def test_store_result_retries(self):
         iterations = [0]
         iterations = [0]
         stop_raising_at = [5]
         stop_raising_at = [5]

+ 3 - 4
celery/tests/backends/test_base.py

@@ -25,9 +25,7 @@ from celery.result import result_from_tuple
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.functional import pass1
 from celery.utils.functional import pass1
 
 
-from celery.tests.case import (
-    ANY, AppCase, Case, Mock, call, patch, skip_if_python3,
-)
+from celery.tests.case import ANY, AppCase, Case, Mock, call, patch, skip
 
 
 
 
 class wrapobject(object):
 class wrapobject(object):
@@ -94,7 +92,8 @@ class test_BaseBackend_interface(AppCase):
 
 
 class test_exception_pickle(AppCase):
 class test_exception_pickle(AppCase):
 
 
-    @skip_if_python3('does not support old style classes')
+    @skip.if_python3(reason='does not support old style classes')
+    @skip.if_pypy()
     def test_oldstyle(self):
     def test_oldstyle(self):
         self.assertTrue(fnpe(Oldstyle()))
         self.assertTrue(fnpe(Oldstyle()))
 
 

+ 15 - 17
celery/tests/backends/test_cache.py

@@ -15,9 +15,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.five import items, module_name_t, string, text_t
 from celery.five import items, module_name_t, string, text_t
 from celery.utils import uuid
 from celery.utils import uuid
 
 
-from celery.tests.case import (
-    AppCase, Mock, override_stdouts, mask_modules, patch, reset_modules,
-)
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 PY3 = sys.version_info[0] == 3
 PY3 = sys.version_info[0] == 3
 
 
@@ -136,8 +134,8 @@ class test_CacheBackend(AppCase):
         b = CacheBackend(backend=backend, app=self.app)
         b = CacheBackend(backend=backend, app=self.app)
         self.assertEqual(b.as_uri(), backend)
         self.assertEqual(b.as_uri(), backend)
 
 
-    @override_stdouts
-    def test_regression_worker_startup_info(self):
+    @mock.stdouts
+    def test_regression_worker_startup_info(self, stdout, stderr):
         self.app.conf.result_backend = (
         self.app.conf.result_backend = (
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
         )
         )
@@ -197,7 +195,7 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
 
 
     def test_pylibmc(self):
     def test_pylibmc(self):
         with self.mock_pylibmc():
         with self.mock_pylibmc():
-            with reset_modules('celery.backends.cache'):
+            with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
                 self.assertEqual(cache.get_best_memcache()[0].__module__,
                 self.assertEqual(cache.get_best_memcache()[0].__module__,
@@ -205,16 +203,16 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
 
 
     def test_memcache(self):
     def test_memcache(self):
         with self.mock_memcache():
         with self.mock_memcache():
-            with reset_modules('celery.backends.cache'):
-                with mask_modules('pylibmc'):
+            with mock.reset_modules('celery.backends.cache'):
+                with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     from celery.backends import cache
                     cache._imp = [None]
                     cache._imp = [None]
                     self.assertEqual(cache.get_best_memcache()[0]().__module__,
                     self.assertEqual(cache.get_best_memcache()[0]().__module__,
                                      'memcache')
                                      'memcache')
 
 
     def test_no_implementations(self):
     def test_no_implementations(self):
-        with mask_modules('pylibmc', 'memcache'):
-            with reset_modules('celery.backends.cache'):
+        with mock.mask_modules('pylibmc', 'memcache'):
+            with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
                 with self.assertRaises(ImproperlyConfigured):
                 with self.assertRaises(ImproperlyConfigured):
@@ -222,7 +220,7 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
 
 
     def test_cached(self):
     def test_cached(self):
         with self.mock_pylibmc():
         with self.mock_pylibmc():
-            with reset_modules('celery.backends.cache'):
+            with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
                 cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
                 cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
@@ -240,8 +238,8 @@ class test_memcache_key(AppCase, MockCacheMixin):
 
 
     def test_memcache_unicode_key(self):
     def test_memcache_unicode_key(self):
         with self.mock_memcache():
         with self.mock_memcache():
-            with reset_modules('celery.backends.cache'):
-                with mask_modules('pylibmc'):
+            with mock.reset_modules('celery.backends.cache'):
+                with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     from celery.backends import cache
                     cache._imp = [None]
                     cache._imp = [None]
                     task_id, result = string(uuid()), 42
                     task_id, result = string(uuid()), 42
@@ -251,8 +249,8 @@ class test_memcache_key(AppCase, MockCacheMixin):
 
 
     def test_memcache_bytes_key(self):
     def test_memcache_bytes_key(self):
         with self.mock_memcache():
         with self.mock_memcache():
-            with reset_modules('celery.backends.cache'):
-                with mask_modules('pylibmc'):
+            with mock.reset_modules('celery.backends.cache'):
+                with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     from celery.backends import cache
                     cache._imp = [None]
                     cache._imp = [None]
                     task_id, result = str_to_bytes(uuid()), 42
                     task_id, result = str_to_bytes(uuid()), 42
@@ -261,7 +259,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     self.assertEqual(b.get_result(task_id), result)
                     self.assertEqual(b.get_result(task_id), result)
 
 
     def test_pylibmc_unicode_key(self):
     def test_pylibmc_unicode_key(self):
-        with reset_modules('celery.backends.cache'):
+        with mock.reset_modules('celery.backends.cache'):
             with self.mock_pylibmc():
             with self.mock_pylibmc():
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
@@ -271,7 +269,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 self.assertEqual(b.get_result(task_id), result)
                 self.assertEqual(b.get_result(task_id), result)
 
 
     def test_pylibmc_bytes_key(self):
     def test_pylibmc_bytes_key(self):
-        with reset_modules('celery.backends.cache'):
+        with mock.reset_modules('celery.backends.cache'):
             with self.mock_pylibmc():
             with self.mock_pylibmc():
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]

+ 130 - 140
celery/tests/backends/test_cassandra.py

@@ -6,13 +6,12 @@ from datetime import datetime
 from celery import states
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
-from celery.tests.case import (
-    AppCase, Mock, mock_module, depends_on_current_app
-)
+from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
 
 
 CASSANDRA_MODULES = ['cassandra', 'cassandra.auth', 'cassandra.cluster']
 CASSANDRA_MODULES = ['cassandra', 'cassandra.auth', 'cassandra.cluster']
 
 
 
 
+@mock.module(*CASSANDRA_MODULES)
 class test_CassandraBackend(AppCase):
 class test_CassandraBackend(AppCase):
 
 
     def setup(self):
     def setup(self):
@@ -22,174 +21,165 @@ class test_CassandraBackend(AppCase):
             cassandra_table='task_results',
             cassandra_table='task_results',
         )
         )
 
 
-    def test_init_no_cassandra(self):
-        """should raise ImproperlyConfigured when no python-driver
-        installed."""
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            prev, mod.cassandra = mod.cassandra, None
-            try:
-                with self.assertRaises(ImproperlyConfigured):
-                    mod.CassandraBackend(app=self.app)
-            finally:
-                mod.cassandra = prev
-
-    def test_init_with_and_without_LOCAL_QUROM(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            mod.cassandra = Mock()
-
-            cons = mod.cassandra.ConsistencyLevel = Bunch(
-                LOCAL_QUORUM='foo',
-            )
-
-            self.app.conf.cassandra_read_consistency = 'LOCAL_FOO'
-            self.app.conf.cassandra_write_consistency = 'LOCAL_FOO'
-
-            mod.CassandraBackend(app=self.app)
-            cons.LOCAL_FOO = 'bar'
-            mod.CassandraBackend(app=self.app)
-
-            # no servers raises ImproperlyConfigured
+    def test_init_no_cassandra(self, *modules):
+        # should raise ImproperlyConfigured when no python-driver
+        # installed.
+        from celery.backends import cassandra as mod
+        prev, mod.cassandra = mod.cassandra, None
+        try:
             with self.assertRaises(ImproperlyConfigured):
             with self.assertRaises(ImproperlyConfigured):
-                self.app.conf.cassandra_servers = None
-                mod.CassandraBackend(
-                    app=self.app, keyspace='b', column_family='c',
-                )
+                mod.CassandraBackend(app=self.app)
+        finally:
+            mod.cassandra = prev
 
 
-    @depends_on_current_app
-    def test_reduce(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends.cassandra import CassandraBackend
-            self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
+    def test_init_with_and_without_LOCAL_QUROM(self, *modules):
+        from celery.backends import cassandra as mod
+        mod.cassandra = Mock()
 
 
-    def test_get_task_meta_for(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            mod.cassandra = Mock()
+        cons = mod.cassandra.ConsistencyLevel = Bunch(
+            LOCAL_QUORUM='foo',
+        )
 
 
-            x = mod.CassandraBackend(app=self.app)
-            x._connection = True
-            session = x._session = Mock()
-            execute = session.execute = Mock()
-            execute.return_value = [
-                [states.SUCCESS, '1', datetime.now(), b'', b'']
-            ]
-            x.decode = Mock()
-            meta = x._get_task_meta_for('task_id')
-            self.assertEqual(meta['status'], states.SUCCESS)
-
-            x._session.execute.return_value = []
-            meta = x._get_task_meta_for('task_id')
-            self.assertEqual(meta['status'], states.PENDING)
-
-    def test_store_result(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            mod.cassandra = Mock()
+        self.app.conf.cassandra_read_consistency = 'LOCAL_FOO'
+        self.app.conf.cassandra_write_consistency = 'LOCAL_FOO'
 
 
-            x = mod.CassandraBackend(app=self.app)
-            x._connection = True
-            session = x._session = Mock()
-            session.execute = Mock()
-            x._store_result('task_id', 'result', states.SUCCESS)
+        mod.CassandraBackend(app=self.app)
+        cons.LOCAL_FOO = 'bar'
+        mod.CassandraBackend(app=self.app)
 
 
-    def test_process_cleanup(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            x = mod.CassandraBackend(app=self.app)
-            x.process_cleanup()
+        # no servers raises ImproperlyConfigured
+        with self.assertRaises(ImproperlyConfigured):
+            self.app.conf.cassandra_servers = None
+            mod.CassandraBackend(
+                app=self.app, keyspace='b', column_family='c',
+            )
 
 
-            self.assertIsNone(x._connection)
-            self.assertIsNone(x._session)
+    @depends_on_current_app
+    def test_reduce(self, *modules):
+        from celery.backends.cassandra import CassandraBackend
+        self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
+
+    def test_get_task_meta_for(self, *modules):
+        from celery.backends import cassandra as mod
+        mod.cassandra = Mock()
+
+        x = mod.CassandraBackend(app=self.app)
+        x._connection = True
+        session = x._session = Mock()
+        execute = session.execute = Mock()
+        execute.return_value = [
+            [states.SUCCESS, '1', datetime.now(), b'', b'']
+        ]
+        x.decode = Mock()
+        meta = x._get_task_meta_for('task_id')
+        self.assertEqual(meta['status'], states.SUCCESS)
+
+        x._session.execute.return_value = []
+        meta = x._get_task_meta_for('task_id')
+        self.assertEqual(meta['status'], states.PENDING)
+
+    def test_store_result(self, *modules):
+        from celery.backends import cassandra as mod
+        mod.cassandra = Mock()
+
+        x = mod.CassandraBackend(app=self.app)
+        x._connection = True
+        session = x._session = Mock()
+        session.execute = Mock()
+        x._store_result('task_id', 'result', states.SUCCESS)
+
+    def test_process_cleanup(self, *modules):
+        from celery.backends import cassandra as mod
+        x = mod.CassandraBackend(app=self.app)
+        x.process_cleanup()
+
+        self.assertIsNone(x._connection)
+        self.assertIsNone(x._session)
 
 
     def test_timeouting_cluster(self):
     def test_timeouting_cluster(self):
-        """Tests behaviour when Cluster.connect raises
-        cassandra.OperationTimedOut."""
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
+        # Tests behaviour when Cluster.connect raises
+        # cassandra.OperationTimedOut.
+        from celery.backends import cassandra as mod
 
 
-            class OTOExc(Exception):
-                pass
+        class OTOExc(Exception):
+            pass
 
 
-            class VeryFaultyCluster(object):
-                def __init__(self, *args, **kwargs):
-                    pass
+        class VeryFaultyCluster(object):
+            def __init__(self, *args, **kwargs):
+                pass
 
 
-                def connect(self, *args, **kwargs):
-                    raise OTOExc()
+            def connect(self, *args, **kwargs):
+                raise OTOExc()
 
 
-                def shutdown(self):
-                    pass
+            def shutdown(self):
+                pass
 
 
-            mod.cassandra = Mock()
-            mod.cassandra.OperationTimedOut = OTOExc
-            mod.cassandra.cluster = Mock()
-            mod.cassandra.cluster.Cluster = VeryFaultyCluster
+        mod.cassandra = Mock()
+        mod.cassandra.OperationTimedOut = OTOExc
+        mod.cassandra.cluster = Mock()
+        mod.cassandra.cluster.Cluster = VeryFaultyCluster
 
 
-            x = mod.CassandraBackend(app=self.app)
+        x = mod.CassandraBackend(app=self.app)
 
 
-            with self.assertRaises(OTOExc):
-                x._store_result('task_id', 'result', states.SUCCESS)
-            self.assertIsNone(x._connection)
-            self.assertIsNone(x._session)
+        with self.assertRaises(OTOExc):
+            x._store_result('task_id', 'result', states.SUCCESS)
+        self.assertIsNone(x._connection)
+        self.assertIsNone(x._session)
 
 
-            x.process_cleanup()  # should not raise
+        x.process_cleanup()  # should not raise
 
 
     def test_please_free_memory(self):
     def test_please_free_memory(self):
-        """Ensure that Cluster object IS shut down."""
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
+        # Ensure that Cluster object IS shut down.
+        from celery.backends import cassandra as mod
 
 
-            class RAMHoggingCluster(object):
+        class RAMHoggingCluster(object):
 
 
-                objects_alive = 0
+            objects_alive = 0
 
 
-                def __init__(self, *args, **kwargs):
-                    pass
+            def __init__(self, *args, **kwargs):
+                pass
 
 
-                def connect(self, *args, **kwargs):
-                    RAMHoggingCluster.objects_alive += 1
-                    return Mock()
+            def connect(self, *args, **kwargs):
+                RAMHoggingCluster.objects_alive += 1
+                return Mock()
 
 
-                def shutdown(self):
-                    RAMHoggingCluster.objects_alive -= 1
+            def shutdown(self):
+                RAMHoggingCluster.objects_alive -= 1
 
 
-            mod.cassandra = Mock()
+        mod.cassandra = Mock()
 
 
-            mod.cassandra.cluster = Mock()
-            mod.cassandra.cluster.Cluster = RAMHoggingCluster
+        mod.cassandra.cluster = Mock()
+        mod.cassandra.cluster.Cluster = RAMHoggingCluster
 
 
-            for x in range(0, 10):
-                x = mod.CassandraBackend(app=self.app)
-                x._store_result('task_id', 'result', states.SUCCESS)
-                x.process_cleanup()
+        for x in range(0, 10):
+            x = mod.CassandraBackend(app=self.app)
+            x._store_result('task_id', 'result', states.SUCCESS)
+            x.process_cleanup()
 
 
-            self.assertEquals(RAMHoggingCluster.objects_alive, 0)
+        self.assertEquals(RAMHoggingCluster.objects_alive, 0)
 
 
     def test_auth_provider(self):
     def test_auth_provider(self):
-        """Ensure valid auth_provider works properly, and invalid one raises
-        ImproperlyConfigured exception."""
+        # Ensure valid auth_provider works properly, and invalid one raises
+        # ImproperlyConfigured exception.
+        from celery.backends import cassandra as mod
+
         class DummyAuth(object):
         class DummyAuth(object):
             ValidAuthProvider = Mock()
             ValidAuthProvider = Mock()
 
 
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-
-            mod.cassandra = Mock()
-            mod.cassandra.auth = DummyAuth
-
-            # Valid auth_provider
-            self.app.conf.cassandra_auth_provider = 'ValidAuthProvider'
-            self.app.conf.cassandra_auth_kwargs = {
-                'username': 'stuff'
-            }
+        mod.cassandra = Mock()
+        mod.cassandra.auth = DummyAuth
+
+        # Valid auth_provider
+        self.app.conf.cassandra_auth_provider = 'ValidAuthProvider'
+        self.app.conf.cassandra_auth_kwargs = {
+            'username': 'stuff'
+        }
+        mod.CassandraBackend(app=self.app)
+
+        # Invalid auth_provider
+        self.app.conf.cassandra_auth_provider = 'SpiderManAuth'
+        self.app.conf.cassandra_auth_kwargs = {
+            'username': 'Jack'
+        }
+        with self.assertRaises(ImproperlyConfigured):
             mod.CassandraBackend(app=self.app)
             mod.CassandraBackend(app=self.app)
-
-            # Invalid auth_provider
-            self.app.conf.cassandra_auth_provider = 'SpiderManAuth'
-            self.app.conf.cassandra_auth_kwargs = {
-                'username': 'Jack'
-            }
-            with self.assertRaises(ImproperlyConfigured):
-                mod.CassandraBackend(app=self.app)

+ 2 - 4
celery/tests/backends/test_couchbase.py

@@ -8,9 +8,7 @@ from celery.backends import couchbase as module
 from celery.backends.couchbase import CouchBaseBackend
 from celery.backends.couchbase import CouchBaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery import backends
-from celery.tests.case import (
-    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
-)
+from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 
 try:
 try:
     import couchbase
     import couchbase
@@ -20,7 +18,7 @@ except ImportError:
 COUCHBASE_BUCKET = 'celery_bucket'
 COUCHBASE_BUCKET = 'celery_bucket'
 
 
 
 
-@skip_unless_module('couchbase')
+@skip.unless_module('couchbase')
 class test_CouchBaseBackend(AppCase):
 class test_CouchBaseBackend(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 2 - 4
celery/tests/backends/test_couchdb.py

@@ -4,9 +4,7 @@ from celery.backends import couchdb as module
 from celery.backends.couchdb import CouchBackend
 from celery.backends.couchdb import CouchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery import backends
-from celery.tests.case import (
-    AppCase, Mock, patch, sentinel, skip_unless_module,
-)
+from celery.tests.case import AppCase, Mock, patch, sentinel, skip
 
 
 try:
 try:
     import pycouchdb
     import pycouchdb
@@ -16,7 +14,7 @@ except ImportError:
 COUCHDB_CONTAINER = 'celery_container'
 COUCHDB_CONTAINER = 'celery_container'
 
 
 
 
-@skip_unless_module('pycouchdb')
+@skip.unless_module('pycouchdb')
 class test_CouchBackend(AppCase):
 class test_CouchBackend(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 6 - 12
celery/tests/backends/test_database.py

@@ -9,13 +9,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.utils import uuid
 from celery.utils import uuid
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase,
-    Mock,
-    depends_on_current_app,
-    patch,
-    skip_if_pypy,
-    skip_if_jython,
-    skip_unless_module,
+    AppCase, Mock, depends_on_current_app, patch, skip,
 )
 )
 
 
 try:
 try:
@@ -38,7 +32,7 @@ class SomeClass(object):
         self.data = data
         self.data = data
 
 
 
 
-@skip_unless_module('sqlalchemy')
+@skip.unless_module('sqlalchemy')
 class test_session_cleanup(AppCase):
 class test_session_cleanup(AppCase):
 
 
     def test_context(self):
     def test_context(self):
@@ -56,9 +50,9 @@ class test_session_cleanup(AppCase):
         session.close.assert_called_with()
         session.close.assert_called_with()
 
 
 
 
-@skip_unless_module('sqlalchemy')
-@skip_if_pypy()
-@skip_if_jython()
+@skip.unless_module('sqlalchemy')
+@skip.if_pypy()
+@skip.if_jython()
 class test_DatabaseBackend(AppCase):
 class test_DatabaseBackend(AppCase):
 
 
     def setup(self):
     def setup(self):
@@ -214,7 +208,7 @@ class test_DatabaseBackend(AppCase):
         self.assertIn('foo', repr(TaskSet('foo', None)))
         self.assertIn('foo', repr(TaskSet('foo', None)))
 
 
 
 
-@skip_unless_module('sqlalchemy')
+@skip.unless_module('sqlalchemy')
 class test_SessionManager(AppCase):
 class test_SessionManager(AppCase):
 
 
     def test_after_fork(self):
     def test_after_fork(self):

+ 2 - 2
celery/tests/backends/test_elasticsearch.py

@@ -5,10 +5,10 @@ from celery.backends import elasticsearch as module
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
-from celery.tests.case import AppCase, Mock, sentinel, skip_unless_module
+from celery.tests.case import AppCase, Mock, sentinel, skip
 
 
 
 
-@skip_unless_module('elasticsearch')
+@skip.unless_module('elasticsearch')
 class test_ElasticsearchBackend(AppCase):
 class test_ElasticsearchBackend(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 2 - 2
celery/tests/backends/test_filesystem.py

@@ -10,10 +10,10 @@ from celery.backends.filesystem import FilesystemBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.utils import uuid
 from celery.utils import uuid
 
 
-from celery.tests.case import AppCase, skip_if_win32
+from celery.tests.case import AppCase, skip
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_FilesystemBackend(AppCase):
 class test_FilesystemBackend(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 6 - 7
celery/tests/backends/test_mongodb.py

@@ -12,9 +12,8 @@ from celery.backends import mongodb as module
 from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, ANY,
-    depends_on_current_app, override_stdouts, patch, sentinel,
-    skip_unless_module,
+    ANY, AppCase, MagicMock, Mock,
+    mock, depends_on_current_app, patch, sentinel, skip,
 )
 )
 
 
 COLLECTION = 'taskmeta_celery'
 COLLECTION = 'taskmeta_celery'
@@ -28,7 +27,7 @@ MONGODB_COLLECTION = 'collection1'
 MONGODB_GROUP_COLLECTION = 'group_collection1'
 MONGODB_GROUP_COLLECTION = 'group_collection1'
 
 
 
 
-@skip_unless_module('pymongo')
+@skip.unless_module('pymongo')
 class test_MongoBackend(AppCase):
 class test_MongoBackend(AppCase):
 
 
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
@@ -407,8 +406,8 @@ class test_MongoBackend(AppCase):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
         self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
 
 
-    @override_stdouts
-    def test_regression_worker_startup_info(self):
+    @mock.stdouts
+    def test_regression_worker_startup_info(self, stdout, stderr):
         self.app.conf.result_backend = (
         self.app.conf.result_backend = (
             'mongodb://user:password@host0.com:43437,host1.com:43437'
             'mongodb://user:password@host0.com:43437,host1.com:43437'
             '/work4us?replicaSet=rs&ssl=true'
             '/work4us?replicaSet=rs&ssl=true'
@@ -418,7 +417,7 @@ class test_MongoBackend(AppCase):
         self.assertTrue(worker.startup_info())
         self.assertTrue(worker.startup_info())
 
 
 
 
-@skip_unless_module('pymongo')
+@skip.unless_module('pymongo')
 class test_MongoBackend_no_mock(AppCase):
 class test_MongoBackend_no_mock(AppCase):
 
 
     def test_encode_decode(self):
     def test_encode_decode(self):

+ 4 - 4
celery/tests/backends/test_redis.py

@@ -13,8 +13,8 @@ from celery.datastructures import AttributeDict
 from celery.exceptions import ChordError, ImproperlyConfigured
 from celery.exceptions import ChordError, ImproperlyConfigured
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    ANY, AppCase, ContextMock, Mock, MockCallbacks,
-    call, depends_on_current_app, patch, skip_unless_module,
+    ANY, AppCase, ContextMock, Mock, mock,
+    call, depends_on_current_app, patch, skip,
 )
 )
 
 
 
 
@@ -58,7 +58,7 @@ class Pipeline(object):
         return [step(*a, **kw) for step, a, kw in self.steps]
         return [step(*a, **kw) for step, a, kw in self.steps]
 
 
 
 
-class Redis(MockCallbacks):
+class Redis(mock.MockCallbacks):
     Connection = Connection
     Connection = Connection
     Pipeline = Pipeline
     Pipeline = Pipeline
 
 
@@ -142,7 +142,7 @@ class test_RedisBackend(AppCase):
         self.b = self.Backend(app=self.app)
         self.b = self.Backend(app=self.app)
 
 
     @depends_on_current_app
     @depends_on_current_app
-    @skip_unless_module('redis')
+    @skip.unless_module('redis')
     def test_reduce(self):
     def test_reduce(self):
         from celery.backends.redis import RedisBackend
         from celery.backends.redis import RedisBackend
         x = RedisBackend(app=self.app)
         x = RedisBackend(app=self.app)

+ 2 - 5
celery/tests/backends/test_riak.py

@@ -5,15 +5,12 @@ from __future__ import absolute_import, unicode_literals
 from celery.backends import riak as module
 from celery.backends import riak as module
 from celery.backends.riak import RiakBackend
 from celery.backends.riak import RiakBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import (
-    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
-)
-
+from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 
 RIAK_BUCKET = 'riak_bucket'
 RIAK_BUCKET = 'riak_bucket'
 
 
 
 
-@skip_unless_module('riak')
+@skip.unless_module('riak')
 class test_RiakBackend(AppCase):
 class test_RiakBackend(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 9 - 9
celery/tests/bin/test_base.py

@@ -12,7 +12,7 @@ from celery.five import module_name_t
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, override_stdouts, patch,
+    AppCase, Mock, depends_on_current_app, mock, patch,
 )
 )
 
 
 
 
@@ -144,14 +144,14 @@ class test_Command(AppCase):
         self.assertDictContainsSubset({'foo': 'bar', 'prog_name': 'foo'},
         self.assertDictContainsSubset({'foo': 'bar', 'prog_name': 'foo'},
                                       kwargs2)
                                       kwargs2)
 
 
-    def test_with_bogus_args(self):
-        with override_stdouts() as (_, stderr):
-            cmd = MockCommand(app=self.app)
-            cmd.supports_args = False
-            with self.assertRaises(SystemExit):
-                cmd.execute_from_commandline(argv=['--bogus'])
-            self.assertTrue(stderr.getvalue())
-            self.assertIn('Unrecognized', stderr.getvalue())
+    @mock.stdouts
+    def test_with_bogus_args(self, _, stderr):
+        cmd = MockCommand(app=self.app)
+        cmd.supports_args = False
+        with self.assertRaises(SystemExit):
+            cmd.execute_from_commandline(argv=['--bogus'])
+        self.assertTrue(stderr.getvalue())
+        self.assertIn('Unrecognized', stderr.getvalue())
 
 
     def test_with_custom_config_module(self):
     def test_with_custom_config_module(self):
         prev = os.environ.pop('CELERY_CONFIG_MODULE', None)
         prev = os.environ.pop('CELERY_CONFIG_MODULE', None)

+ 26 - 24
celery/tests/bin/test_beat.py

@@ -10,8 +10,7 @@ from celery import platforms
 from celery.bin import beat as beat_bin
 from celery.bin import beat as beat_bin
 from celery.apps import beat as beatapp
 from celery.apps import beat as beatapp
 
 
-from celery.tests.case import AppCase, Mock, patch, restore_logging
-from kombu.tests.case import redirect_stdouts
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 
 
 class MockedShelveModule(object):
 class MockedShelveModule(object):
@@ -113,32 +112,35 @@ class test_Beat(AppCase):
         self.assertTrue(MockService.in_sync)
         self.assertTrue(MockService.in_sync)
         MockService.in_sync = False
         MockService.in_sync = False
 
 
+    @mock.restore_logging()
     def test_setup_logging(self):
     def test_setup_logging(self):
-        with restore_logging():
-            try:
-                # py3k
-                delattr(sys.stdout, 'logger')
-            except AttributeError:
-                pass
-            b = beatapp.Beat(app=self.app, redirect_stdouts=False)
-            b.redirect_stdouts = False
-            b.app.log.already_setup = False
-            b.setup_logging()
-            with self.assertRaises(AttributeError):
-                sys.stdout.logger
-
-    @redirect_stdouts
+        try:
+            # py3k
+            delattr(sys.stdout, 'logger')
+        except AttributeError:
+            pass
+        b = beatapp.Beat(app=self.app, redirect_stdouts=False)
+        b.redirect_stdouts = False
+        b.app.log.already_setup = False
+        b.setup_logging()
+        with self.assertRaises(AttributeError):
+            sys.stdout.logger
+
+    import sys
+    orig_stdout = sys.__stdout__
+
     @patch('celery.apps.beat.logger')
     @patch('celery.apps.beat.logger')
+    @mock.restore_logging()
+    @mock.stdouts
     def test_logs_errors(self, logger, stdout, stderr):
     def test_logs_errors(self, logger, stdout, stderr):
-        with restore_logging():
-            b = MockBeat3(
-                app=self.app, redirect_stdouts=False, socket_timeout=None,
-            )
-            b.start_scheduler()
-            self.assertTrue(logger.critical.called)
-
-    @redirect_stdouts
+        b = MockBeat3(
+            app=self.app, redirect_stdouts=False, socket_timeout=None,
+        )
+        b.start_scheduler()
+        self.assertTrue(logger.critical.called)
+
     @patch('celery.platforms.create_pidlock')
     @patch('celery.platforms.create_pidlock')
+    @mock.stdouts
     def test_use_pidfile(self, create_pidlock, stdout, stderr):
     def test_use_pidfile(self, create_pidlock, stdout, stderr):
         b = MockBeat2(app=self.app, pidfile='pidfilelockfilepid',
         b = MockBeat2(app=self.app, pidfile='pidfilelockfilepid',
                       socket_timeout=None, redirect_stdouts=False)
                       socket_timeout=None, redirect_stdouts=False)

+ 2 - 2
celery/tests/bin/test_celeryd_detach.py

@@ -7,7 +7,7 @@ from celery.bin.celeryd_detach import (
     main,
     main,
 )
 )
 
 
-from celery.tests.case import AppCase, Mock, override_stdouts, patch
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 
 
 if not IS_WINDOWS:
 if not IS_WINDOWS:
@@ -76,7 +76,7 @@ class test_PartialOptionParser(AppCase):
         ])
         ])
         self.assertEqual(options.pidfile, '/var/pid/foo.pid')
         self.assertEqual(options.pidfile, '/var/pid/foo.pid')
 
 
-        with override_stdouts():
+        with mock.stdouts():
             with self.assertRaises(SystemExit):
             with self.assertRaises(SystemExit):
                 p.parse_args(['--logfile'])
                 p.parse_args(['--logfile'])
             p.get_option('--logfile').nargs = 2
             p.get_option('--logfile').nargs = 2

+ 2 - 2
celery/tests/bin/test_events.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import, unicode_literals
 
 
 from celery.bin import events
 from celery.bin import events
 
 
-from celery.tests.case import AppCase, patch, _old_patch, skip_unless_module
+from celery.tests.case import AppCase, patch, _old_patch, skip
 
 
 
 
 class MockCommand(object):
 class MockCommand(object):
@@ -29,7 +29,7 @@ class test_events(AppCase):
         self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
         self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
         self.assertIn('celery events:dump', proctitle.last[0])
         self.assertIn('celery events:dump', proctitle.last[0])
 
 
-    @skip_unless_module('curses')
+    @skip.unless_module('curses', import_errors=(ImportError, OSError))
     def test_run_top(self):
     def test_run_top(self):
         @_old_patch('celery.events.cursesmon', 'evtop',
         @_old_patch('celery.events.cursesmon', 'evtop',
                     lambda **kw: 'me top, you?')
                     lambda **kw: 'me top, you?')

+ 2 - 2
celery/tests/bin/test_multi.py

@@ -17,7 +17,7 @@ from celery.bin.multi import (
 )
 )
 from celery.five import WhateverIO
 from celery.five import WhateverIO
 
 
-from celery.tests.case import AppCase, Mock, patch, skip_unless_symbol
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 
 
 class test_functions(AppCase):
 class test_functions(AppCase):
@@ -265,7 +265,7 @@ class test_MultiTool(AppCase):
 
 
         )
         )
 
 
-    @skip_unless_symbol('signal.SIGKILL')
+    @skip.unless_symbol('signal.SIGKILL')
     def test_kill(self):
     def test_kill(self):
         self.t.getpids = Mock()
         self.t.getpids = Mock()
         self.t.getpids.return_value = [
         self.t.getpids.return_value = [

+ 49 - 59
celery/tests/bin/test_worker.py

@@ -18,17 +18,7 @@ from celery.exceptions import (
 from celery.platforms import EX_FAILURE, EX_OK
 from celery.platforms import EX_FAILURE, EX_OK
 from celery.worker import state
 from celery.worker import state
 
 
-from celery.tests.case import (
-    AppCase,
-    Mock,
-    override_stdouts,
-    patch,
-    skip_if_jython,
-    skip_if_pypy,
-    skip_if_win32,
-    skip_unless_module,
-    skip_unless_symbol,
-)
+from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 
 
 
 class WorkerAppCase(AppCase):
 class WorkerAppCase(AppCase):
@@ -47,14 +37,14 @@ class Worker(cd.Worker):
 class test_Worker(WorkerAppCase):
 class test_Worker(WorkerAppCase):
     Worker = Worker
     Worker = Worker
 
 
-    @override_stdouts
-    def test_queues_string(self):
+    @mock.stdouts
+    def test_queues_string(self, stdout, stderr):
         w = self.app.Worker()
         w = self.app.Worker()
         w.setup_queues('foo,bar,baz')
         w.setup_queues('foo,bar,baz')
         self.assertIn('foo', self.app.amqp.queues)
         self.assertIn('foo', self.app.amqp.queues)
 
 
-    @override_stdouts
-    def test_cpu_count(self):
+    @mock.stdouts
+    def test_cpu_count(self, stdout, stderr):
         with patch('celery.worker.cpu_count') as cpu_count:
         with patch('celery.worker.cpu_count') as cpu_count:
             cpu_count.side_effect = NotImplementedError()
             cpu_count.side_effect = NotImplementedError()
             w = self.app.Worker(concurrency=None)
             w = self.app.Worker(concurrency=None)
@@ -62,8 +52,8 @@ class test_Worker(WorkerAppCase):
         w = self.app.Worker(concurrency=5)
         w = self.app.Worker(concurrency=5)
         self.assertEqual(w.concurrency, 5)
         self.assertEqual(w.concurrency, 5)
 
 
-    @override_stdouts
-    def test_windows_B_option(self):
+    @mock.stdouts
+    def test_windows_B_option(self, stdout, stderr):
         self.app.IS_WINDOWS = True
         self.app.IS_WINDOWS = True
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(beat=True)
             worker(app=self.app).run(beat=True)
@@ -94,8 +84,8 @@ class test_Worker(WorkerAppCase):
                 x.maybe_detach(['--detach'])
                 x.maybe_detach(['--detach'])
             self.assertTrue(detached.called)
             self.assertTrue(detached.called)
 
 
-    @override_stdouts
-    def test_invalid_loglevel_gives_error(self):
+    @mock.stdouts
+    def test_invalid_loglevel_gives_error(self, stdout, stderr):
         x = worker(app=self.app)
         x = worker(app=self.app)
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             x.run(loglevel='GRIM_REAPER')
             x.run(loglevel='GRIM_REAPER')
@@ -118,13 +108,13 @@ class test_Worker(WorkerAppCase):
         worker.loglevel = logging.INFO
         worker.loglevel = logging.INFO
         self.assertTrue(worker.extra_info())
         self.assertTrue(worker.extra_info())
 
 
-    @override_stdouts
-    def test_loglevel_string(self):
+    @mock.stdouts
+    def test_loglevel_string(self, stdout, stderr):
         worker = self.Worker(app=self.app, loglevel='INFO')
         worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
         self.assertEqual(worker.loglevel, logging.INFO)
 
 
-    @override_stdouts
-    def test_run_worker(self):
+    @mock.stdouts
+    def test_run_worker(self, stdout, stderr):
         handlers = {}
         handlers = {}
 
 
         class Signals(platforms.Signals):
         class Signals(platforms.Signals):
@@ -151,8 +141,8 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             platforms.signals = p
             platforms.signals = p
 
 
-    @override_stdouts
-    def test_startup_info(self):
+    @mock.stdouts
+    def test_startup_info(self, stdout, stderr):
         worker = self.Worker(app=self.app)
         worker = self.Worker(app=self.app)
         worker.on_start()
         worker.on_start()
         self.assertTrue(worker.startup_info())
         self.assertTrue(worker.startup_info())
@@ -189,19 +179,19 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             cd.ARTLINES = prev
             cd.ARTLINES = prev
 
 
-    @override_stdouts
-    def test_run(self):
+    @mock.stdouts
+    def test_run(self, stdout, stderr):
         self.Worker(app=self.app).on_start()
         self.Worker(app=self.app).on_start()
         self.Worker(app=self.app, purge=True).on_start()
         self.Worker(app=self.app, purge=True).on_start()
         worker = self.Worker(app=self.app)
         worker = self.Worker(app=self.app)
         worker.on_start()
         worker.on_start()
 
 
-    @override_stdouts
-    def test_purge_messages(self):
+    @mock.stdouts
+    def test_purge_messages(self, stdout, stderr):
         self.Worker(app=self.app).purge_messages()
         self.Worker(app=self.app).purge_messages()
 
 
-    @override_stdouts
-    def test_init_queues(self):
+    @mock.stdouts
+    def test_init_queues(self, stdout, stderr):
         app = self.app
         app = self.app
         c = app.conf
         c = app.conf
         app.amqp.queues = app.amqp.Queues({
         app.amqp.queues = app.amqp.Queues({
@@ -231,8 +221,8 @@ class test_Worker(WorkerAppCase):
             app.amqp.queues['image'],
             app.amqp.queues['image'],
         )
         )
 
 
-    @override_stdouts
-    def test_autoscale_argument(self):
+    @mock.stdouts
+    def test_autoscale_argument(self, stdout, stderr):
         worker1 = self.Worker(app=self.app, autoscale='10,3')
         worker1 = self.Worker(app=self.app, autoscale='10,3')
         self.assertListEqual(worker1.autoscale, [10, 3])
         self.assertListEqual(worker1.autoscale, [10, 3])
         worker2 = self.Worker(app=self.app, autoscale='10')
         worker2 = self.Worker(app=self.app, autoscale='10')
@@ -247,17 +237,17 @@ class test_Worker(WorkerAppCase):
         self.assertListEqual(worker2.include, ['os', 'sys'])
         self.assertListEqual(worker2.include, ['os', 'sys'])
         self.Worker(app=self.app, include=['os', 'sys'])
         self.Worker(app=self.app, include=['os', 'sys'])
 
 
-    @override_stdouts
-    def test_unknown_loglevel(self):
+    @mock.stdouts
+    def test_unknown_loglevel(self, stdout, stderr):
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(loglevel='ALIEN')
             worker(app=self.app).run(loglevel='ALIEN')
         worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
 
-    @skip_if_win32
-    @override_stdouts
     @patch('os._exit')
     @patch('os._exit')
-    def test_warns_if_running_as_privileged_user(self, _exit):
+    @skip.if_win32()
+    @mock.stdouts
+    def test_warns_if_running_as_privileged_user(self, _exit, stdout, stderr):
         with patch('os.getuid') as getuid:
         with patch('os.getuid') as getuid:
             getuid.return_value = 0
             getuid.return_value = 0
             self.app.conf.accept_content = ['pickle']
             self.app.conf.accept_content = ['pickle']
@@ -281,14 +271,14 @@ class test_Worker(WorkerAppCase):
                 worker = self.Worker(app=self.app)
                 worker = self.Worker(app=self.app)
                 worker.on_start()
                 worker.on_start()
 
 
-    @override_stdouts
-    def test_redirect_stdouts(self):
+    @mock.stdouts
+    def test_redirect_stdouts(self, stdout, stderr):
         self.Worker(app=self.app, redirect_stdouts=False)
         self.Worker(app=self.app, redirect_stdouts=False)
         with self.assertRaises(AttributeError):
         with self.assertRaises(AttributeError):
             sys.stdout.logger
             sys.stdout.logger
 
 
-    @override_stdouts
-    def test_on_start_custom_logging(self):
+    @mock.stdouts
+    def test_on_start_custom_logging(self, stdout, stderr):
         self.app.log.redirect_stdouts = Mock()
         self.app.log.redirect_stdouts = Mock()
         worker = self.Worker(app=self.app, redirect_stoutds=True)
         worker = self.Worker(app=self.app, redirect_stoutds=True)
         worker._custom_logging = True
         worker._custom_logging = True
@@ -306,8 +296,8 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             self.app.log.setup = prev
             self.app.log.setup = prev
 
 
-    @override_stdouts
-    def test_startup_info_pool_is_str(self):
+    @mock.stdouts
+    def test_startup_info_pool_is_str(self, stdout, stderr):
         worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker.pool_cls = 'foo'
         worker.pool_cls = 'foo'
         worker.startup_info()
         worker.startup_info()
@@ -329,8 +319,8 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
             signals.setup_logging.disconnect(on_logging_setup)
 
 
-    @override_stdouts
-    def test_platform_tweaks_osx(self):
+    @mock.stdouts
+    def test_platform_tweaks_osx(self, stdout, stderr):
 
 
         class OSXWorker(Worker):
         class OSXWorker(Worker):
             proxy_workaround_installed = False
             proxy_workaround_installed = False
@@ -357,8 +347,8 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             cd.install_HUP_not_supported_handler = prev
             cd.install_HUP_not_supported_handler = prev
 
 
-    @override_stdouts
-    def test_general_platform_tweaks(self):
+    @mock.stdouts
+    def test_general_platform_tweaks(self, stdout, stderr):
 
 
         restart_worker_handler_installed = [False]
         restart_worker_handler_installed = [False]
 
 
@@ -378,8 +368,8 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             cd.install_worker_restart_handler = prev
             cd.install_worker_restart_handler = prev
 
 
-    @override_stdouts
-    def test_on_consumer_ready(self):
+    @mock.stdouts
+    def test_on_consumer_ready(self, stdout, stderr):
         worker_ready_sent = [False]
         worker_ready_sent = [False]
 
 
         @signals.worker_ready.connect
         @signals.worker_ready.connect
@@ -390,13 +380,13 @@ class test_Worker(WorkerAppCase):
         self.assertTrue(worker_ready_sent[0])
         self.assertTrue(worker_ready_sent[0])
 
 
 
 
-@override_stdouts
+@mock.stdouts
 class test_funs(WorkerAppCase):
 class test_funs(WorkerAppCase):
 
 
     def test_active_thread_count(self):
     def test_active_thread_count(self):
         self.assertTrue(cd.active_thread_count())
         self.assertTrue(cd.active_thread_count())
 
 
-    @skip_unless_module('setproctitle')
+    @skip.unless_module('setproctitle')
     def test_set_process_status(self):
     def test_set_process_status(self):
         worker = Worker(app=self.app, hostname='xyzza')
         worker = Worker(app=self.app, hostname='xyzza')
         prev1, sys.argv = sys.argv, ['Arg0']
         prev1, sys.argv = sys.argv, ['Arg0']
@@ -435,7 +425,7 @@ class test_funs(WorkerAppCase):
             sys.argv = s
             sys.argv = s
 
 
 
 
-@override_stdouts
+@mock.stdouts
 class test_signal_handlers(WorkerAppCase):
 class test_signal_handlers(WorkerAppCase):
 
 
     class _Worker(object):
     class _Worker(object):
@@ -504,7 +494,7 @@ class test_signal_handlers(WorkerAppCase):
             with self.assertRaises(WorkerTerminate):
             with self.assertRaises(WorkerTerminate):
                 next_handlers['SIGINT']('SIGINT', object())
                 next_handlers['SIGINT']('SIGINT', object())
 
 
-    @skip_unless_module('multiprocessing')
+    @skip.unless_module('multiprocessing')
     def test_worker_int_handler_only_stop_MainProcess(self):
     def test_worker_int_handler_only_stop_MainProcess(self):
         process = current_process()
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         name, process.name = process.name, 'OtherProcess'
@@ -535,7 +525,7 @@ class test_signal_handlers(WorkerAppCase):
         handlers = self.psig(cd.install_HUP_not_supported_handler, worker)
         handlers = self.psig(cd.install_HUP_not_supported_handler, worker)
         handlers['SIGHUP']('SIGHUP', object())
         handlers['SIGHUP']('SIGHUP', object())
 
 
-    @skip_unless_module('multiprocessing')
+    @skip.unless_module('multiprocessing')
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
         process = current_process()
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         name, process.name = process.name, 'OtherProcess'
@@ -586,14 +576,14 @@ class test_signal_handlers(WorkerAppCase):
                 state.should_stop = None
                 state.should_stop = None
 
 
     @patch('sys.__stderr__')
     @patch('sys.__stderr__')
-    @skip_if_pypy()
-    @skip_if_jython()
+    @skip.if_pypy()
+    @skip.if_jython()
     def test_worker_cry_handler(self, stderr):
     def test_worker_cry_handler(self, stderr):
         handlers = self.psig(cd.install_cry_handler)
         handlers = self.psig(cd.install_cry_handler)
         self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
         self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
         self.assertTrue(stderr.write.called)
         self.assertTrue(stderr.write.called)
 
 
-    @skip_unless_module('multiprocessing')
+    @skip.unless_module('multiprocessing')
     def test_worker_term_handler_only_stop_MainProcess(self):
     def test_worker_term_handler_only_stop_MainProcess(self):
         process = current_process()
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         name, process.name = process.name, 'OtherProcess'
@@ -614,7 +604,7 @@ class test_signal_handlers(WorkerAppCase):
             process.name = name
             process.name = name
             state.should_stop = None
             state.should_stop = None
 
 
-    @skip_unless_symbol('os.execv')
+    @skip.unless_symbol('os.execv')
     @patch('celery.platforms.close_open_fds')
     @patch('celery.platforms.close_open_fds')
     @patch('atexit.register')
     @patch('atexit.register')
     @patch('os.close')
     @patch('os.close')

+ 9 - 5
celery/tests/case.py

@@ -8,7 +8,6 @@ import os
 import sys
 import sys
 import threading
 import threading
 
 
-from contextlib import contextmanager
 from copy import deepcopy
 from copy import deepcopy
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from functools import partial, wraps
 from functools import partial, wraps
@@ -22,10 +21,16 @@ from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
 
 
-from ._case import *  # noqa
-from ._case import __all__ as _case_all, Case as _Case, decorator
+from case import (
+    ANY, ContextMock, MagicMock, Mock, call, mock, skip, patch, sentinel,
+)
+from case import Case as _Case
+from case.utils import decorator
+
+__all__ = [
+    'ANY', 'ContextMock', 'MagicMock', 'Mock',
+    'call', 'mock', 'skip', 'patch', 'sentinel',
 
 
-__all__ = _case_all + [
     'AppCase', 'TaskMessage', 'TaskMessage1',
     'AppCase', 'TaskMessage', 'TaskMessage1',
     'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
     'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
 ]
 ]
@@ -246,7 +251,6 @@ class AppCase(Case):
 
 
 
 
 @decorator
 @decorator
-@contextmanager
 def assert_signal_called(signal, **expected):
 def assert_signal_called(signal, **expected):
     handler = Mock()
     handler = Mock()
     call_handler = partial(handler)
     call_handler = partial(handler)

+ 2 - 2
celery/tests/concurrency/test_eventlet.py

@@ -9,10 +9,10 @@ from celery.concurrency.eventlet import (
     TaskPool,
     TaskPool,
 )
 )
 
 
-from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 
 
-@skip_if_pypy()
+@skip.if_pypy()
 class EventletCase(AppCase):
 class EventletCase(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 2 - 2
celery/tests/concurrency/test_gevent.py

@@ -6,7 +6,7 @@ from celery.concurrency.gevent import (
     apply_timeout,
     apply_timeout,
 )
 )
 
 
-from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 gevent_modules = (
 gevent_modules = (
     'gevent',
     'gevent',
@@ -17,7 +17,7 @@ gevent_modules = (
 )
 )
 
 
 
 
-@skip_if_pypy()
+@skip.if_pypy()
 class GeventCase(AppCase):
 class GeventCase(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 2 - 2
celery/tests/concurrency/test_pool.py

@@ -5,7 +5,7 @@ import itertools
 
 
 from billiard.einfo import ExceptionInfo
 from billiard.einfo import ExceptionInfo
 
 
-from celery.tests.case import AppCase, skip_unless_module
+from celery.tests.case import AppCase, skip
 
 
 
 
 def do_something(i):
 def do_something(i):
@@ -23,7 +23,7 @@ def raise_something(i):
         return ExceptionInfo()
         return ExceptionInfo()
 
 
 
 
-@skip_unless_module('multiprocessing')
+@skip.unless_module('multiprocessing')
 class test_TaskPool(AppCase):
 class test_TaskPool(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 5 - 7
celery/tests/concurrency/test_prefork.py

@@ -12,9 +12,7 @@ from celery.five import range
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import (
-    AppCase, Mock, patch, restore_logging, skip_if_win32, skip_unless_module,
-)
+from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 
 try:
 try:
     from celery.concurrency import prefork as mp
     from celery.concurrency import prefork as mp
@@ -60,7 +58,7 @@ class test_process_initializer(AppCase):
     @patch('celery.platforms.signals')
     @patch('celery.platforms.signals')
     @patch('celery.platforms.set_mp_process_title')
     @patch('celery.platforms.set_mp_process_title')
     def test_process_initializer(self, set_mp_process_title, _signals):
     def test_process_initializer(self, set_mp_process_title, _signals):
-        with restore_logging():
+        with mock.restore_logging():
             from celery import signals
             from celery import signals
             from celery._state import _tls
             from celery._state import _tls
             from celery.concurrency.prefork import (
             from celery.concurrency.prefork import (
@@ -186,12 +184,12 @@ class ExeMockTaskPool(mp.TaskPool):
     Pool = BlockingPool = ExeMockPool
     Pool = BlockingPool = ExeMockPool
 
 
 
 
-@skip_unless_module('multiprocessing')
+@skip.unless_module('multiprocessing')
 class PoolCase(AppCase):
 class PoolCase(AppCase):
     pass
     pass
 
 
 
 
-@skip_if_win32
+@skip.if_win32()
 class test_AsynPool(PoolCase):
 class test_AsynPool(PoolCase):
 
 
     def test_gen_not_started(self):
     def test_gen_not_started(self):
@@ -297,7 +295,7 @@ class test_AsynPool(PoolCase):
         w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
         w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
 
 
 
 
-@skip_if_win32
+@skip.if_win32()
 class test_ResultHandler(PoolCase):
 class test_ResultHandler(PoolCase):
 
 
     def test_process_result(self):
     def test_process_result(self):

+ 6 - 6
celery/tests/concurrency/test_threads.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import, unicode_literals
 
 
 from celery.concurrency.threads import NullDict, TaskPool, apply_target
 from celery.concurrency.threads import NullDict, TaskPool, apply_target
 
 
-from celery.tests.case import AppCase, Case, Mock, mask_modules, mock_module
+from celery.tests.case import AppCase, Case, Mock, mock
 
 
 
 
 class test_NullDict(Case):
 class test_NullDict(Case):
@@ -18,32 +18,32 @@ class test_TaskPool(AppCase):
 
 
     def test_without_threadpool(self):
     def test_without_threadpool(self):
 
 
-        with mask_modules('threadpool'):
+        with mock.mask_modules('threadpool'):
             with self.assertRaises(ImportError):
             with self.assertRaises(ImportError):
                 TaskPool(app=self.app)
                 TaskPool(app=self.app)
 
 
     def test_with_threadpool(self):
     def test_with_threadpool(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             x = TaskPool(app=self.app)
             self.assertTrue(x.ThreadPool)
             self.assertTrue(x.ThreadPool)
             self.assertTrue(x.WorkRequest)
             self.assertTrue(x.WorkRequest)
 
 
     def test_on_start(self):
     def test_on_start(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             x = TaskPool(app=self.app)
             x.on_start()
             x.on_start()
             self.assertTrue(x._pool)
             self.assertTrue(x._pool)
             self.assertIsInstance(x._pool.workRequests, NullDict)
             self.assertIsInstance(x._pool.workRequests, NullDict)
 
 
     def test_on_stop(self):
     def test_on_stop(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             x = TaskPool(app=self.app)
             x.on_start()
             x.on_start()
             x.on_stop()
             x.on_stop()
             x._pool.dismissWorkers.assert_called_with(x.limit, do_join=True)
             x._pool.dismissWorkers.assert_called_with(x.limit, do_join=True)
 
 
     def test_on_apply(self):
     def test_on_apply(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             x = TaskPool(app=self.app)
             x.on_start()
             x.on_start()
             callback = Mock()
             callback = Mock()

+ 5 - 5
celery/tests/contrib/test_migrate.py

@@ -26,7 +26,7 @@ from celery.contrib.migrate import (
     move,
     move,
 )
 )
 from celery.utils.encoding import bytes_t, ensure_bytes
 from celery.utils.encoding import bytes_t, ensure_bytes
-from celery.tests.case import AppCase, Mock, override_stdouts, patch
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 # hack to ignore error at shutdown
 # hack to ignore error at shutdown
 QoS.restore_at_shutdown = False
 QoS.restore_at_shutdown = False
@@ -213,10 +213,10 @@ class test_utils(AppCase):
         self.assertEqual(_maybe_queue(app, 'foo'), 313)
         self.assertEqual(_maybe_queue(app, 'foo'), 313)
         self.assertEqual(_maybe_queue(app, Queue('foo')), Queue('foo'))
         self.assertEqual(_maybe_queue(app, Queue('foo')), Queue('foo'))
 
 
-    def test_filter_status(self):
-        with override_stdouts() as (stdout, stderr):
-            filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
-            self.assertTrue(stdout.getvalue())
+    @mock.stdouts
+    def test_filter_status(self, stdout, stderr):
+        filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
+        self.assertTrue(stdout.getvalue())
 
 
     def test_move_by_taskmap(self):
     def test_move_by_taskmap(self):
         with patch('celery.contrib.migrate.move') as move:
         with patch('celery.contrib.migrate.move') as move:

+ 3 - 3
celery/tests/contrib/test_rdb.py

@@ -9,7 +9,7 @@ from celery.contrib.rdb import (
     set_trace,
     set_trace,
 )
 )
 from celery.five import WhateverIO
 from celery.five import WhateverIO
-from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 
 
 class SockErr(socket.error):
 class SockErr(socket.error):
@@ -32,7 +32,7 @@ class test_Rdb(AppCase):
         self.assertTrue(debugger.return_value.set_trace.called)
         self.assertTrue(debugger.return_value.set_trace.called)
 
 
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
-    @skip_if_pypy()
+    @skip.if_pypy()
     def test_rdb(self, get_avail_port):
     def test_rdb(self, get_avail_port):
         sock = Mock()
         sock = Mock()
         get_avail_port.return_value = (sock, 8000)
         get_avail_port.return_value = (sock, 8000)
@@ -76,7 +76,7 @@ class test_Rdb(AppCase):
             rdb.set_quit.assert_called_with()
             rdb.set_quit.assert_called_with()
 
 
     @patch('socket.socket')
     @patch('socket.socket')
-    @skip_if_pypy()
+    @skip.if_pypy()
     def test_get_avail_port(self, sock):
     def test_get_avail_port(self, sock):
         out = WhateverIO()
         out = WhateverIO()
         sock.return_value.accept.return_value = (Mock(), ['helu'])
         sock.return_value.accept.return_value = (Mock(), ['helu'])

+ 2 - 2
celery/tests/events/test_cursesmon.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-from celery.tests.case import AppCase, skip_unless_module
+from celery.tests.case import AppCase, skip
 
 
 
 
 class MockWindow(object):
 class MockWindow(object):
@@ -9,7 +9,7 @@ class MockWindow(object):
         return self.y, self.x
         return self.y, self.x
 
 
 
 
-@skip_unless_module('curses')
+@skip.unless_module('curses', import_errors=(ImportError, OSError))
 class test_CursesDisplay(AppCase):
 class test_CursesDisplay(AppCase):
 
 
     def setup(self):
     def setup(self):

+ 11 - 10
celery/tests/events/test_snapshot.py

@@ -2,7 +2,8 @@ from __future__ import absolute_import, unicode_literals
 
 
 from celery.events import Events
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
 from celery.events.snapshot import Polaroid, evcam
-from celery.tests.case import AppCase, patch, restore_logging
+
+from celery.tests.case import AppCase, mock, patch
 
 
 
 
 class TRef(object):
 class TRef(object):
@@ -113,16 +114,16 @@ class test_evcam(AppCase):
         self.app.events = self.MockEvents()
         self.app.events = self.MockEvents()
         self.app.events.app = self.app
         self.app.events.app = self.app
 
 
+    @mock.restore_logging()
     def test_evcam(self):
     def test_evcam(self):
-        with restore_logging():
-            evcam(Polaroid, timer=timer, app=self.app)
-            evcam(Polaroid, timer=timer, loglevel='CRITICAL', app=self.app)
-            self.MockReceiver.raise_keyboard_interrupt = True
-            try:
-                with self.assertRaises(SystemExit):
-                    evcam(Polaroid, timer=timer, app=self.app)
-            finally:
-                self.MockReceiver.raise_keyboard_interrupt = False
+        evcam(Polaroid, timer=timer, app=self.app)
+        evcam(Polaroid, timer=timer, loglevel='CRITICAL', app=self.app)
+        self.MockReceiver.raise_keyboard_interrupt = True
+        try:
+            with self.assertRaises(SystemExit):
+                evcam(Polaroid, timer=timer, app=self.app)
+        finally:
+            self.MockReceiver.raise_keyboard_interrupt = False
 
 
     @patch('celery.platforms.create_pidlock')
     @patch('celery.platforms.create_pidlock')
     def test_evcam_pidfile(self, create_pidlock):
     def test_evcam_pidfile(self, create_pidlock):

+ 2 - 2
celery/tests/events/test_state.py

@@ -19,7 +19,7 @@ from celery.events.state import (
 )
 )
 from celery.five import range
 from celery.five import range
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.tests.case import AppCase, Mock, patch, todo
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 try:
 try:
     Decimal(2.6)
     Decimal(2.6)
@@ -374,7 +374,7 @@ class test_State(AppCase):
         self.assertEqual(now[1][0], tC)
         self.assertEqual(now[1][0], tC)
         self.assertEqual(now[2][0], tB)
         self.assertEqual(now[2][0], tB)
 
 
-    @todo(reason='not working')
+    @skip.todo(reason='not working')
     def test_task_descending_clock_ordering(self):
     def test_task_descending_clock_ordering(self):
         state = State()
         state = State()
         r = ev_logical_clock_ordering(state)
         r = ev_logical_clock_ordering(state)

+ 11 - 13
celery/tests/fixups/test_django.py

@@ -12,9 +12,7 @@ from celery.fixups.django import (
 )
 )
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import (
-    AppCase, Mock, patch, patch_modules, mask_modules,
-)
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 
 
 class FixupCase(AppCase):
 class FixupCase(AppCase):
@@ -78,11 +76,11 @@ class test_DjangoFixup(FixupCase):
                 fixup(self.app)
                 fixup(self.app)
                 self.assertFalse(Fixup.called)
                 self.assertFalse(Fixup.called)
             with patch.dict(os.environ, DJANGO_SETTINGS_MODULE='settings'):
             with patch.dict(os.environ, DJANGO_SETTINGS_MODULE='settings'):
-                with mask_modules('django'):
+                with mock.mask_modules('django'):
                     with self.assertWarnsRegex(UserWarning, 'but Django is'):
                     with self.assertWarnsRegex(UserWarning, 'but Django is'):
                         fixup(self.app)
                         fixup(self.app)
                         self.assertFalse(Fixup.called)
                         self.assertFalse(Fixup.called)
-                with patch_modules('django'):
+                with mock.patch_modules('django'):
                     fixup(self.app)
                     fixup(self.app)
                     self.assertTrue(Fixup.called)
                     self.assertTrue(Fixup.called)
 
 
@@ -334,7 +332,7 @@ class test_DjangoWorkerFixup(FixupCase):
         django.setup.assert_called_with()
         django.setup.assert_called_with()
 
 
     def test_mysql_errors(self):
     def test_mysql_errors(self):
-        with patch_modules('MySQLdb'):
+        with mock.patch_modules('MySQLdb'):
             import MySQLdb as mod
             import MySQLdb as mod
             mod.DatabaseError = Mock()
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
             mod.InterfaceError = Mock()
@@ -343,12 +341,12 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('MySQLdb'):
+        with mock.mask_modules('MySQLdb'):
             with self.fixup_context(self.app):
             with self.fixup_context(self.app):
                 pass
                 pass
 
 
     def test_pg_errors(self):
     def test_pg_errors(self):
-        with patch_modules('psycopg2'):
+        with mock.patch_modules('psycopg2'):
             import psycopg2 as mod
             import psycopg2 as mod
             mod.DatabaseError = Mock()
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
             mod.InterfaceError = Mock()
@@ -357,12 +355,12 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('psycopg2'):
+        with mock.mask_modules('psycopg2'):
             with self.fixup_context(self.app):
             with self.fixup_context(self.app):
                 pass
                 pass
 
 
     def test_sqlite_errors(self):
     def test_sqlite_errors(self):
-        with patch_modules('sqlite3'):
+        with mock.patch_modules('sqlite3'):
             import sqlite3 as mod
             import sqlite3 as mod
             mod.DatabaseError = Mock()
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
             mod.InterfaceError = Mock()
@@ -371,12 +369,12 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('sqlite3'):
+        with mock.mask_modules('sqlite3'):
             with self.fixup_context(self.app):
             with self.fixup_context(self.app):
                 pass
                 pass
 
 
     def test_oracle_errors(self):
     def test_oracle_errors(self):
-        with patch_modules('cx_Oracle'):
+        with mock.patch_modules('cx_Oracle'):
             import cx_Oracle as mod
             import cx_Oracle as mod
             mod.DatabaseError = Mock()
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
             mod.InterfaceError = Mock()
@@ -385,6 +383,6 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('cx_Oracle'):
+        with mock.mask_modules('cx_Oracle'):
             with self.fixup_context(self.app):
             with self.fixup_context(self.app):
                 pass
                 pass

+ 2 - 2
celery/tests/security/case.py

@@ -1,8 +1,8 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-from celery.tests.case import AppCase, skip_unless_module
+from celery.tests.case import AppCase, skip
 
 
 
 
-@skip_unless_module('OpenSSL.crypto', name='pyOpenSSL')
+@skip.unless_module('OpenSSL.crypto', name='pyOpenSSL')
 class SecurityCase(AppCase):
 class SecurityCase(AppCase):
     pass
     pass

+ 3 - 3
celery/tests/security/test_certificate.py

@@ -6,7 +6,7 @@ from celery.security.certificate import Certificate, CertStore, FSCertStore
 from . import CERT1, CERT2, KEY1
 from . import CERT1, CERT2, KEY1
 from .case import SecurityCase
 from .case import SecurityCase
 
 
-from celery.tests.case import Mock, mock_open, patch, todo
+from celery.tests.case import Mock, mock, patch, skip
 
 
 
 
 class test_Certificate(SecurityCase):
 class test_Certificate(SecurityCase):
@@ -27,7 +27,7 @@ class test_Certificate(SecurityCase):
         with self.assertRaises(SecurityError):
         with self.assertRaises(SecurityError):
             Certificate(KEY1)
             Certificate(KEY1)
 
 
-    @todo(reason='cert expired')
+    @skip.todo(reason='cert expired')
     def test_has_expired(self):
     def test_has_expired(self):
         self.assertFalse(Certificate(CERT1).has_expired())
         self.assertFalse(Certificate(CERT1).has_expired())
 
 
@@ -68,7 +68,7 @@ class test_FSCertStore(SecurityCase):
         cert.has_expired.return_value = False
         cert.has_expired.return_value = False
         isdir.return_value = True
         isdir.return_value = True
         glob.return_value = ['foo.cert']
         glob.return_value = ['foo.cert']
-        with mock_open():
+        with mock.open():
             cert.get_id.return_value = 1
             cert.get_id.return_value = 1
             x = FSCertStore('/var/certs')
             x = FSCertStore('/var/certs')
             self.assertIn(1, x._certs)
             self.assertIn(1, x._certs)

+ 2 - 2
celery/tests/security/test_security.py

@@ -26,7 +26,7 @@ from kombu.serialization import registry
 
 
 from .case import SecurityCase
 from .case import SecurityCase
 
 
-from celery.tests.case import Mock, mock_open, patch
+from celery.tests.case import Mock, mock, patch
 
 
 
 
 class test_security(SecurityCase):
 class test_security(SecurityCase):
@@ -86,7 +86,7 @@ class test_security(SecurityCase):
                 calls[0] += 1
                 calls[0] += 1
 
 
         self.app.conf.task_serializer = 'auth'
         self.app.conf.task_serializer = 'auth'
-        with mock_open(side_effect=effect):
+        with mock.open(side_effect=effect):
             with patch('celery.security.registry') as registry:
             with patch('celery.security.registry') as registry:
                 store = Mock()
                 store = Mock()
                 self.app.setup_security(['json'], key, cert, store)
                 self.app.setup_security(['json'], key, cert, store)

+ 0 - 0
celery/tests/slow/__init__.py


+ 2 - 2
celery/tests/utils/test_datastructures.py

@@ -18,7 +18,7 @@ from celery.datastructures import (
 from celery.five import WhateverIO, items
 from celery.five import WhateverIO, items
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import Case, Mock, skip_if_win32
+from celery.tests.case import Case, Mock, skip
 
 
 
 
 class test_DictAttribute(Case):
 class test_DictAttribute(Case):
@@ -167,7 +167,7 @@ class test_ExceptionInfo(Case):
             self.assertTrue(r)
             self.assertTrue(r)
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_LimitedSet(Case):
 class test_LimitedSet(Case):
 
 
     def test_add(self):
     def test_add(self):

+ 21 - 25
celery/tests/utils/test_platforms.py

@@ -38,11 +38,7 @@ try:
 except ImportError:  # pragma: no cover
 except ImportError:  # pragma: no cover
     resource = None  # noqa
     resource = None  # noqa
 
 
-from celery.tests._case import open_fqdn
-from celery.tests.case import (
-    Case, Mock,
-    call, override_stdouts, mock_open, patch, skip_if_win32,
-)
+from celery.tests.case import Case, Mock, call, mock, patch, skip
 
 
 
 
 class test_find_option_with_arg(Case):
 class test_find_option_with_arg(Case):
@@ -60,7 +56,7 @@ class test_find_option_with_arg(Case):
         )
         )
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_fd_by_path(Case):
 class test_fd_by_path(Case):
 
 
     def test_finds(self):
     def test_finds(self):
@@ -141,7 +137,7 @@ class test_Signals(Case):
         self.assertTrue(signals.supported('INT'))
         self.assertTrue(signals.supported('INT'))
         self.assertFalse(signals.supported('SIGIMAGINARY'))
         self.assertFalse(signals.supported('SIGIMAGINARY'))
 
 
-    @skip_if_win32()
+    @skip.if_win32()
     def test_reset_alarm(self):
     def test_reset_alarm(self):
         with patch('signal.alarm') as _alarm:
         with patch('signal.alarm') as _alarm:
             signals.reset_alarm()
             signals.reset_alarm()
@@ -186,7 +182,7 @@ class test_Signals(Case):
         signals['INT'] = lambda *a: a
         signals['INT'] = lambda *a: a
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_get_fdmax(Case):
 class test_get_fdmax(Case):
 
 
     @patch('resource.getrlimit')
     @patch('resource.getrlimit')
@@ -205,7 +201,7 @@ class test_get_fdmax(Case):
             self.assertEqual(get_fdmax(None), 13)
             self.assertEqual(get_fdmax(None), 13)
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_maybe_drop_privileges(Case):
 class test_maybe_drop_privileges(Case):
 
 
     def test_on_windows(self):
     def test_on_windows(self):
@@ -322,7 +318,7 @@ class test_maybe_drop_privileges(Case):
         self.assertFalse(setuid.called)
         self.assertFalse(setuid.called)
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_setget_uid_gid(Case):
 class test_setget_uid_gid(Case):
 
 
     @patch('celery.platforms.parse_uid')
     @patch('celery.platforms.parse_uid')
@@ -380,7 +376,7 @@ class test_setget_uid_gid(Case):
             parse_gid('group')
             parse_gid('group')
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_initgroups(Case):
 class test_initgroups(Case):
 
 
     @patch('pwd.getpwuid')
     @patch('pwd.getpwuid')
@@ -416,7 +412,7 @@ class test_initgroups(Case):
                 os.initgroups = prev
                 os.initgroups = prev
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_detached(Case):
 class test_detached(Case):
 
 
     def test_without_resource(self):
     def test_without_resource(self):
@@ -431,7 +427,7 @@ class test_detached(Case):
     @patch('celery.platforms.signals')
     @patch('celery.platforms.signals')
     @patch('celery.platforms.maybe_drop_privileges')
     @patch('celery.platforms.maybe_drop_privileges')
     @patch('os.geteuid')
     @patch('os.geteuid')
-    @patch(open_fqdn)
+    @patch(mock.open_fqdn)
     def test_default(self, open, geteuid, maybe_drop,
     def test_default(self, open, geteuid, maybe_drop,
                      signals, pidlock):
                      signals, pidlock):
         geteuid.return_value = 0
         geteuid.return_value = 0
@@ -456,7 +452,7 @@ class test_detached(Case):
         pidlock.assert_called_with('/foo/bar/pid')
         pidlock.assert_called_with('/foo/bar/pid')
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_DaemonContext(Case):
 class test_DaemonContext(Case):
 
 
     @patch('os.fork')
     @patch('os.fork')
@@ -522,7 +518,7 @@ class test_DaemonContext(Case):
             x.open()
             x.open()
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_Pidfile(Case):
 class test_Pidfile(Case):
 
 
     @patch('celery.platforms.Pidfile')
     @patch('celery.platforms.Pidfile')
@@ -530,7 +526,7 @@ class test_Pidfile(Case):
         p = Pidfile.return_value = Mock()
         p = Pidfile.return_value = Mock()
         p.is_locked.return_value = True
         p.is_locked.return_value = True
         p.remove_if_stale.return_value = False
         p.remove_if_stale.return_value = False
-        with override_stdouts() as (_, err):
+        with mock.stdouts() as (_, err):
             with self.assertRaises(SystemExit):
             with self.assertRaises(SystemExit):
                 create_pidlock('/var/pid')
                 create_pidlock('/var/pid')
             self.assertIn('already exists', err.getvalue())
             self.assertIn('already exists', err.getvalue())
@@ -567,14 +563,14 @@ class test_Pidfile(Case):
         self.assertFalse(p.is_locked())
         self.assertFalse(p.is_locked())
 
 
     def test_read_pid(self):
     def test_read_pid(self):
-        with mock_open() as s:
+        with mock.open() as s:
             s.write('1816\n')
             s.write('1816\n')
             s.seek(0)
             s.seek(0)
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
             self.assertEqual(p.read_pid(), 1816)
             self.assertEqual(p.read_pid(), 1816)
 
 
     def test_read_pid_partially_written(self):
     def test_read_pid_partially_written(self):
-        with mock_open() as s:
+        with mock.open() as s:
             s.write('1816')
             s.write('1816')
             s.seek(0)
             s.seek(0)
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
@@ -584,20 +580,20 @@ class test_Pidfile(Case):
     def test_read_pid_raises_ENOENT(self):
     def test_read_pid_raises_ENOENT(self):
         exc = IOError()
         exc = IOError()
         exc.errno = errno.ENOENT
         exc.errno = errno.ENOENT
-        with mock_open(side_effect=exc):
+        with mock.open(side_effect=exc):
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
             self.assertIsNone(p.read_pid())
             self.assertIsNone(p.read_pid())
 
 
     def test_read_pid_raises_IOError(self):
     def test_read_pid_raises_IOError(self):
         exc = IOError()
         exc = IOError()
         exc.errno = errno.EAGAIN
         exc.errno = errno.EAGAIN
-        with mock_open(side_effect=exc):
+        with mock.open(side_effect=exc):
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
             with self.assertRaises(IOError):
             with self.assertRaises(IOError):
                 p.read_pid()
                 p.read_pid()
 
 
     def test_read_pid_bogus_pidfile(self):
     def test_read_pid_bogus_pidfile(self):
-        with mock_open() as s:
+        with mock.open() as s:
             s.write('eighteensixteen\n')
             s.write('eighteensixteen\n')
             s.seek(0)
             s.seek(0)
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
@@ -655,7 +651,7 @@ class test_Pidfile(Case):
 
 
     @patch('os.kill')
     @patch('os.kill')
     def test_remove_if_stale_process_dead(self, kill):
     def test_remove_if_stale_process_dead(self, kill):
-        with override_stdouts():
+        with mock.stdouts():
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
             p.read_pid = Mock()
             p.read_pid.return_value = 1816
             p.read_pid.return_value = 1816
@@ -668,7 +664,7 @@ class test_Pidfile(Case):
             p.remove.assert_called_with()
             p.remove.assert_called_with()
 
 
     def test_remove_if_stale_broken_pid(self):
     def test_remove_if_stale_broken_pid(self):
-        with override_stdouts():
+        with mock.stdouts():
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
             p.read_pid = Mock()
             p.read_pid.side_effect = ValueError()
             p.read_pid.side_effect = ValueError()
@@ -690,7 +686,7 @@ class test_Pidfile(Case):
     @patch('os.getpid')
     @patch('os.getpid')
     @patch('os.open')
     @patch('os.open')
     @patch('os.fdopen')
     @patch('os.fdopen')
-    @patch(open_fqdn)
+    @patch(mock.open_fqdn)
     def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
     def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
         getpid.return_value = 1816
         getpid.return_value = 1816
         osopen.return_value = 13
         osopen.return_value = 13
@@ -717,7 +713,7 @@ class test_Pidfile(Case):
     @patch('os.getpid')
     @patch('os.getpid')
     @patch('os.open')
     @patch('os.open')
     @patch('os.fdopen')
     @patch('os.fdopen')
-    @patch(open_fqdn)
+    @patch(mock.open_fqdn)
     def test_write_reread_fails(self, open_, fdopen,
     def test_write_reread_fails(self, open_, fdopen,
                                 osopen, getpid, fsync):
                                 osopen, getpid, fsync):
         getpid.return_value = 1816
         getpid.return_value = 1816

+ 2 - 2
celery/tests/utils/test_serialization.py

@@ -7,7 +7,7 @@ from celery.utils.serialization import (
     get_pickleable_etype,
     get_pickleable_etype,
 )
 )
 
 
-from celery.tests.case import Case, mask_modules
+from celery.tests.case import Case, mock
 
 
 
 
 class test_AAPickle(Case):
 class test_AAPickle(Case):
@@ -15,7 +15,7 @@ class test_AAPickle(Case):
     def test_no_cpickle(self):
     def test_no_cpickle(self):
         prev = sys.modules.pop('celery.utils.serialization', None)
         prev = sys.modules.pop('celery.utils.serialization', None)
         try:
         try:
-            with mask_modules('cPickle'):
+            with mock.mask_modules('cPickle'):
                 from celery.utils.serialization import pickle
                 from celery.utils.serialization import pickle
                 import pickle as orig_pickle
                 import pickle as orig_pickle
                 self.assertIs(pickle.dumps, orig_pickle.dumps)
                 self.assertIs(pickle.dumps, orig_pickle.dumps)

+ 3 - 3
celery/tests/utils/test_sysinfo.py

@@ -2,10 +2,10 @@ from __future__ import absolute_import, unicode_literals
 
 
 from celery.utils.sysinfo import load_average, df
 from celery.utils.sysinfo import load_average, df
 
 
-from celery.tests.case import Case, patch, skip_unless_symbol
+from celery.tests.case import Case, patch, skip
 
 
 
 
-@skip_unless_symbol('os.getloadavg')
+@skip.unless_symbol('os.getloadavg')
 class test_load_average(Case):
 class test_load_average(Case):
 
 
     def test_avg(self):
     def test_avg(self):
@@ -16,7 +16,7 @@ class test_load_average(Case):
             self.assertEqual(l, (0.55, 0.64, 0.7))
             self.assertEqual(l, (0.55, 0.64, 0.7))
 
 
 
 
-@skip_unless_symbol('posix.statvfs_result')
+@skip.unless_symbol('posix.statvfs_result')
 class test_df(Case):
 class test_df(Case):
 
 
     def test_df(self):
     def test_df(self):

+ 2 - 2
celery/tests/utils/test_term.py

@@ -7,10 +7,10 @@ from celery.utils import term
 from celery.utils.term import colored, fg
 from celery.utils.term import colored, fg
 from celery.five import text_t
 from celery.five import text_t
 
 
-from celery.tests.case import Case, skip_if_win32
+from celery.tests.case import Case, skip
 
 
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_colored(Case):
 class test_colored(Case):
 
 
     def setUp(self):
     def setUp(self):

+ 2 - 2
celery/tests/utils/test_threads.py

@@ -8,7 +8,7 @@ from celery.utils.threads import (
     bgThread,
     bgThread,
 )
 )
 
 
-from celery.tests.case import Case, override_stdouts, patch
+from celery.tests.case import Case, mock, patch
 
 
 
 
 class test_bgThread(Case):
 class test_bgThread(Case):
@@ -21,7 +21,7 @@ class test_bgThread(Case):
                 raise KeyError()
                 raise KeyError()
 
 
         with patch('os._exit') as _exit:
         with patch('os._exit') as _exit:
-            with override_stdouts():
+            with mock.stdouts():
                 _exit.side_effect = ValueError()
                 _exit.side_effect = ValueError()
                 t = T()
                 t = T()
                 with self.assertRaises(ValueError):
                 with self.assertRaises(ValueError):

+ 1 - 1
celery/tests/utils/test_timer2.py

@@ -52,7 +52,7 @@ class test_Timer(Case):
         t = timer2.Timer(on_tick=on_tick)
         t = timer2.Timer(on_tick=on_tick)
         ne = t._next_entry = Mock(name='_next_entry')
         ne = t._next_entry = Mock(name='_next_entry')
         ne.return_value = 3.33
         ne.return_value = 3.33
-        self.on_nth_call_do(ne, t._is_shutdown.set, 3)
+        ne.on_nth_call_do(t._is_shutdown.set, 3)
         t.run()
         t.run()
         sleep.assert_called_with(3.33)
         sleep.assert_called_with(3.33)
         on_tick.assert_has_calls([call(3.33), call(3.33), call(3.33)])
         on_tick.assert_has_calls([call(3.33), call(3.33), call(3.33)])

+ 3 - 3
celery/tests/worker/test_autoreload.py

@@ -18,7 +18,7 @@ from celery.worker.autoreload import (
     Autoreloader,
     Autoreloader,
 )
 )
 
 
-from celery.tests.case import AppCase, Case, Mock, patch, mock_open
+from celery.tests.case import AppCase, Case, Mock, mock, patch
 
 
 
 
 class test_WorkerComponent(AppCase):
 class test_WorkerComponent(AppCase):
@@ -52,11 +52,11 @@ class test_WorkerComponent(AppCase):
 class test_file_hash(Case):
 class test_file_hash(Case):
 
 
     def test_hash(self):
     def test_hash(self):
-        with mock_open() as a:
+        with mock.open() as a:
             a.write('the quick brown fox\n')
             a.write('the quick brown fox\n')
             a.seek(0)
             a.seek(0)
             A = file_hash('foo')
             A = file_hash('foo')
-        with mock_open() as b:
+        with mock.open() as b:
             b.write('the quick brown bar\n')
             b.write('the quick brown bar\n')
             b.seek(0)
             b.seek(0)
             B = file_hash('bar')
             B = file_hash('bar')

+ 4 - 2
celery/tests/worker/test_autoscale.py

@@ -6,9 +6,10 @@ from celery.concurrency.base import BasePool
 from celery.five import monotonic
 from celery.five import monotonic
 from celery.worker import state
 from celery.worker import state
 from celery.worker import autoscale
 from celery.worker import autoscale
-from celery.tests.case import AppCase, Mock, patch, sleepdeprived
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
+from celery.tests.case import AppCase, Mock, mock, patch
+
 
 
 class MockPool(BasePool):
 class MockPool(BasePool):
     shrink_raises_exception = False
     shrink_raises_exception = False
@@ -88,7 +89,7 @@ class test_Autoscaler(AppCase):
         x.stop()
         x.stop()
         self.assertFalse(x.joined)
         self.assertFalse(x.joined)
 
 
-    @sleepdeprived(autoscale)
+    @mock.sleepdeprived(module=autoscale)
     def test_body(self):
     def test_body(self):
         worker = Mock(name='worker')
         worker = Mock(name='worker')
         x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
         x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
@@ -193,6 +194,7 @@ class test_Autoscaler(AppCase):
         _exit.assert_called_with(1)
         _exit.assert_called_with(1)
         self.assertTrue(stderr.write.call_count)
         self.assertTrue(stderr.write.call_count)
 
 
+    @mock.sleepdeprived(module=autoscale)
     def test_no_negative_scale(self):
     def test_no_negative_scale(self):
         total_num_processes = []
         total_num_processes = []
         worker = Mock(name='worker')
         worker = Mock(name='worker')

+ 2 - 2
celery/tests/worker/test_components.py

@@ -7,7 +7,7 @@ from __future__ import absolute_import, unicode_literals
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.worker.components import Beat, Hub, Pool, Timer
 from celery.worker.components import Beat, Hub, Pool, Timer
 
 
-from celery.tests.case import AppCase, Mock, patch, skip_if_win32
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 
 
 class test_Timer(AppCase):
 class test_Timer(AppCase):
@@ -60,7 +60,7 @@ class test_Pool(AppCase):
         comp.close(w)
         comp.close(w)
         comp.terminate(w)
         comp.terminate(w)
 
 
-    @skip_if_win32()
+    @skip.if_win32()
     def test_create_when_eventloop(self):
     def test_create_when_eventloop(self):
         w = Mock()
         w = Mock()
         w.use_eventloop = w.pool_putlocks = w.pool_cls.uses_semaphore = True
         w.use_eventloop = w.pool_putlocks = w.pool_cls.uses_semaphore = True

+ 2 - 4
celery/tests/worker/test_consumer.py

@@ -13,9 +13,7 @@ from celery.worker.consumer.heart import Heart
 from celery.worker.consumer.mingle import Mingle
 from celery.worker.consumer.mingle import Mingle
 from celery.worker.consumer.tasks import Tasks
 from celery.worker.consumer.tasks import Tasks
 
 
-from celery.tests.case import (
-    AppCase, ContextMock, Mock, call, patch, skip_if_python3,
-)
+from celery.tests.case import AppCase, ContextMock, Mock, call, patch, skip
 
 
 
 
 class test_Consumer(AppCase):
 class test_Consumer(AppCase):
@@ -45,7 +43,7 @@ class test_Consumer(AppCase):
         c = self.get_consumer()
         c = self.get_consumer()
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
 
 
-    @skip_if_python3(reason='buffer type not available')
+    @skip.if_python3(reason='buffer type not available')
     def test_dump_body_buffer(self):
     def test_dump_body_buffer(self):
         msg = Mock()
         msg = Mock()
         msg.body = 'str'
         msg.body = 'str'

+ 2 - 2
celery/tests/worker/test_request.py

@@ -48,7 +48,7 @@ from celery.tests.case import (
     TaskMessage,
     TaskMessage,
     task_message_from_sig,
     task_message_from_sig,
     patch,
     patch,
-    skip_if_python3,
+    skip,
 )
 )
 
 
 
 
@@ -123,7 +123,7 @@ def jail(app, task_id, name, args, kwargs):
     ).retval
     ).retval
 
 
 
 
-@skip_if_python3
+@skip.if_python3()
 class test_default_encode(AppCase):
 class test_default_encode(AppCase):
 
 
     def test_jython(self):
     def test_jython(self):

+ 2 - 2
celery/tests/worker/test_worker.py

@@ -30,7 +30,7 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
-from celery.tests.case import AppCase, Mock, TaskMessage, patch, todo
+from celery.tests.case import AppCase, Mock, TaskMessage, patch, skip
 
 
 
 
 def MockStep(step=None):
 def MockStep(step=None):
@@ -849,7 +849,7 @@ class test_WorkController(AppCase):
             self.worker._send_worker_shutdown()
             self.worker._send_worker_shutdown()
             ws.send.assert_called_with(sender=self.worker)
             ws.send.assert_called_with(sender=self.worker)
 
 
-    @todo('unstable test')
+    @skip.todo('unstable test')
     def test_process_shutdown_on_worker_shutdown(self):
     def test_process_shutdown_on_worker_shutdown(self):
         from celery.concurrency.prefork import process_destructor
         from celery.concurrency.prefork import process_destructor
         from celery.concurrency.asynpool import Worker
         from celery.concurrency.asynpool import Worker

+ 1 - 3
requirements/test.txt

@@ -1,3 +1 @@
--r deps/mock.txt
--r deps/nose.txt
-unittest2>=0.5.1
+case

+ 0 - 1
requirements/test3.txt

@@ -1 +0,0 @@
--r deps/nose.txt

+ 1 - 5
setup.py

@@ -171,10 +171,6 @@ install_requires = reqs('default.txt')
 if JYTHON:
 if JYTHON:
     install_requires.extend(reqs('jython.txt'))
     install_requires.extend(reqs('jython.txt'))
 
 
-# -*- Tests Requires -*-
-
-tests_require = reqs('test3.txt' if PY3 else 'test.txt')
-
 # -*- Long Description -*-
 # -*- Long Description -*-
 
 
 if os.path.exists('README.rst'):
 if os.path.exists('README.rst'):
@@ -219,7 +215,7 @@ setup(
     include_package_data=False,
     include_package_data=False,
     zip_safe=False,
     zip_safe=False,
     install_requires=install_requires,
     install_requires=install_requires,
-    tests_require=tests_require,
+    tests_require=reqs('test.txt'),
     test_suite='nose.collector',
     test_suite='nose.collector',
     classifiers=classifiers,
     classifiers=classifiers,
     entry_points=entrypoints,
     entry_points=entrypoints,

+ 1 - 5
tox.ini

@@ -4,15 +4,11 @@ envlist = 2.7,pypy,3.4,3.5,pypy3,flake8,flakeplus
 [testenv]
 [testenv]
 deps=
 deps=
     -r{toxinidir}/requirements/default.txt
     -r{toxinidir}/requirements/default.txt
+    -r{toxinidir}/requirements/test.txt
 
 
-    2.7,pypy: -r{toxinidir}/requirements/test.txt
     2.7: -r{toxinidir}/requirements/test-ci-default.txt
     2.7: -r{toxinidir}/requirements/test-ci-default.txt
-
-    3.4,3.5,pypy3: -r{toxinidir}/requirements/test3.txt
     3.4,3.5: -r{toxinidir}/requirements/test-ci-default.txt
     3.4,3.5: -r{toxinidir}/requirements/test-ci-default.txt
-
     pypy,pypy3: -r{toxinidir}/requirements/test-ci-base.txt
     pypy,pypy3: -r{toxinidir}/requirements/test-ci-base.txt
-    pypy3: -r{toxinidir}/requirements/test-pypy3.txt
 
 
 sitepackages = False
 sitepackages = False
 recreate = False
 recreate = False