Browse Source

[tests] Now depends on case

Ask Solem 9 years ago
parent
commit
d31097ec90
60 changed files with 508 additions and 1359 deletions
  1. 3 0
      celery/__init__.py
  2. 6 4
      celery/contrib/migrate.py
  3. 0 799
      celery/tests/_case.py
  4. 13 17
      celery/tests/app/test_app.py
  5. 2 2
      celery/tests/app/test_beat.py
  6. 9 9
      celery/tests/app/test_defaults.py
  7. 2 2
      celery/tests/app/test_loaders.py
  8. 86 88
      celery/tests/app/test_log.py
  9. 3 3
      celery/tests/app/test_schedules.py
  10. 2 4
      celery/tests/backends/test_amqp.py
  11. 3 4
      celery/tests/backends/test_base.py
  12. 15 17
      celery/tests/backends/test_cache.py
  13. 130 140
      celery/tests/backends/test_cassandra.py
  14. 2 4
      celery/tests/backends/test_couchbase.py
  15. 2 4
      celery/tests/backends/test_couchdb.py
  16. 6 12
      celery/tests/backends/test_database.py
  17. 2 2
      celery/tests/backends/test_elasticsearch.py
  18. 2 2
      celery/tests/backends/test_filesystem.py
  19. 6 7
      celery/tests/backends/test_mongodb.py
  20. 4 4
      celery/tests/backends/test_redis.py
  21. 2 5
      celery/tests/backends/test_riak.py
  22. 9 9
      celery/tests/bin/test_base.py
  23. 26 24
      celery/tests/bin/test_beat.py
  24. 2 2
      celery/tests/bin/test_celeryd_detach.py
  25. 2 2
      celery/tests/bin/test_events.py
  26. 2 2
      celery/tests/bin/test_multi.py
  27. 49 59
      celery/tests/bin/test_worker.py
  28. 9 5
      celery/tests/case.py
  29. 2 2
      celery/tests/concurrency/test_eventlet.py
  30. 2 2
      celery/tests/concurrency/test_gevent.py
  31. 2 2
      celery/tests/concurrency/test_pool.py
  32. 5 7
      celery/tests/concurrency/test_prefork.py
  33. 6 6
      celery/tests/concurrency/test_threads.py
  34. 5 5
      celery/tests/contrib/test_migrate.py
  35. 3 3
      celery/tests/contrib/test_rdb.py
  36. 2 2
      celery/tests/events/test_cursesmon.py
  37. 11 10
      celery/tests/events/test_snapshot.py
  38. 2 2
      celery/tests/events/test_state.py
  39. 11 13
      celery/tests/fixups/test_django.py
  40. 2 2
      celery/tests/security/case.py
  41. 3 3
      celery/tests/security/test_certificate.py
  42. 2 2
      celery/tests/security/test_security.py
  43. 0 0
      celery/tests/slow/__init__.py
  44. 2 2
      celery/tests/utils/test_datastructures.py
  45. 21 25
      celery/tests/utils/test_platforms.py
  46. 2 2
      celery/tests/utils/test_serialization.py
  47. 3 3
      celery/tests/utils/test_sysinfo.py
  48. 2 2
      celery/tests/utils/test_term.py
  49. 2 2
      celery/tests/utils/test_threads.py
  50. 1 1
      celery/tests/utils/test_timer2.py
  51. 3 3
      celery/tests/worker/test_autoreload.py
  52. 4 2
      celery/tests/worker/test_autoscale.py
  53. 2 2
      celery/tests/worker/test_components.py
  54. 2 4
      celery/tests/worker/test_consumer.py
  55. 2 2
      celery/tests/worker/test_request.py
  56. 2 2
      celery/tests/worker/test_worker.py
  57. 1 3
      requirements/test.txt
  58. 0 1
      requirements/test3.txt
  59. 1 5
      setup.py
  60. 1 5
      tox.ini

+ 3 - 0
celery/__init__.py

@@ -158,4 +158,7 @@ old_module, new_module = five.recreate_module(  # pragma: no cover
     version_info_t=version_info_t,
     maybe_patch_concurrency=maybe_patch_concurrency,
     _find_option_with_arg=_find_option_with_arg,
+    absolute_import=absolute_import,
+    unicode_literals=unicode_literals,
+    print_function=print_function,
 )

+ 6 - 4
celery/contrib/migrate.py

@@ -21,10 +21,12 @@ from celery.app import app_or_default
 from celery.five import string, string_t
 from celery.utils import worker_direct
 
-__all__ = ['StopFiltering', 'State', 'republish', 'migrate_task',
-           'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
-           'start_filter', 'move_task_by_id', 'move_by_idmap',
-           'move_by_taskmap', 'move_direct', 'move_direct_by_id']
+__all__ = [
+    'StopFiltering', 'State', 'republish', 'migrate_task',
+    'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
+    'start_filter', 'move_task_by_id', 'move_by_idmap',
+    'move_by_taskmap', 'move_direct', 'move_direct_by_id',
+]
 
 MOVING_PROGRESS_FMT = """\
 Moving task {state.filtered}/{state.strtotal}: \

+ 0 - 799
celery/tests/_case.py

@@ -1,799 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import importlib
-import inspect
-import io
-import logging
-import os
-import platform
-import re
-import sys
-import time
-import types
-import warnings
-
-from contextlib import contextmanager
-from functools import partial, wraps
-from six import (
-    iteritems as items,
-    itervalues as values,
-    string_types,
-    reraise,
-)
-from six.moves import builtins
-
-from nose import SkipTest
-
-try:
-    import unittest  # noqa
-    unittest.skip
-    from unittest.util import safe_repr, unorderable_list_difference
-except AttributeError:
-    import unittest2 as unittest  # noqa
-    from unittest2.util import safe_repr, unorderable_list_difference  # noqa
-
-try:
-    from unittest import mock
-except ImportError:
-    import mock  # noqa
-
-__all__ = [
-    'ANY', 'Case', 'ContextMock', 'MagicMock', 'Mock', 'MockCallbacks',
-    'call', 'patch', 'sentinel',
-
-    'mock_open', 'mock_context', 'mock_module',
-    'patch_modules', 'reset_modules', 'sys_platform', 'pypy_version',
-    'platform_pyimp', 'replace_module_value', 'override_stdouts',
-    'mask_modules', 'sleepdeprived', 'mock_environ', 'wrap_logger',
-    'restore_logging',
-
-    'todo', 'skip', 'skip_if_darwin', 'skip_if_environ',
-    'skip_if_jython', 'skip_if_platform', 'skip_if_pypy', 'skip_if_python3',
-    'skip_if_win32', 'skip_unless_module', 'skip_unless_symbol',
-]
-
-patch = mock.patch
-call = mock.call
-sentinel = mock.sentinel
-MagicMock = mock.MagicMock
-ANY = mock.ANY
-
-PY3 = sys.version_info[0] == 3
-if PY3:
-    open_fqdn = 'builtins.open'
-    module_name_t = str
-else:
-    open_fqdn = '__builtin__.open'  # noqa
-    module_name_t = bytes  # noqa
-
-StringIO = io.StringIO
-_SIO_write = StringIO.write
-_SIO_init = StringIO.__init__
-
-
-def symbol_by_name(name, aliases={}, imp=None, package=None,
-                   sep='.', default=None, **kwargs):
-    """Get symbol by qualified name.
-
-    The name should be the full dot-separated path to the class::
-
-        modulename.ClassName
-
-    Example::
-
-        celery.concurrency.processes.TaskPool
-                                    ^- class name
-
-    or using ':' to separate module and symbol::
-
-        celery.concurrency.processes:TaskPool
-
-    If `aliases` is provided, a dict containing short name/long name
-    mappings, the name is looked up in the aliases first.
-
-    Examples:
-
-        >>> symbol_by_name('celery.concurrency.processes.TaskPool')
-        <class 'celery.concurrency.processes.TaskPool'>
-
-        >>> symbol_by_name('default', {
-        ...     'default': 'celery.concurrency.processes.TaskPool'})
-        <class 'celery.concurrency.processes.TaskPool'>
-
-        # Does not try to look up non-string names.
-        >>> from celery.concurrency.processes import TaskPool
-        >>> symbol_by_name(TaskPool) is TaskPool
-        True
-
-    """
-    if imp is None:
-        imp = importlib.import_module
-
-    if not isinstance(name, string_types):
-        return name                                 # already a class
-
-    name = aliases.get(name) or name
-    sep = ':' if ':' in name else sep
-    module_name, _, cls_name = name.rpartition(sep)
-    if not module_name:
-        cls_name, module_name = None, package if package else cls_name
-    try:
-        try:
-            module = imp(module_name, package=package, **kwargs)
-        except ValueError as exc:
-            reraise(ValueError,
-                    ValueError("Couldn't import {0!r}: {1}".format(name, exc)),
-                    sys.exc_info()[2])
-        return getattr(module, cls_name) if cls_name else module
-    except (ImportError, AttributeError):
-        if default is None:
-            raise
-    return default
-
-
-class WhateverIO(StringIO):
-
-    def __init__(self, v=None, *a, **kw):
-        _SIO_init(self, v.decode() if isinstance(v, bytes) else v, *a, **kw)
-
-    def write(self, data):
-        _SIO_write(self, data.decode() if isinstance(data, bytes) else data)
-
-
-def noop(*args, **kwargs):
-    pass
-
-
-class Mock(mock.Mock):
-
-    def __init__(self, *args, **kwargs):
-        attrs = kwargs.pop('attrs', None) or {}
-        super(Mock, self).__init__(*args, **kwargs)
-        for attr_name, attr_value in items(attrs):
-            setattr(self, attr_name, attr_value)
-
-
-class _ContextMock(Mock):
-    """Dummy class implementing __enter__ and __exit__
-    as the :keyword:`with` statement requires these to be implemented
-    in the class, not just the instance."""
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *exc_info):
-        pass
-
-
-def ContextMock(*args, **kwargs):
-    obj = _ContextMock(*args, **kwargs)
-    obj.attach_mock(_ContextMock(), '__enter__')
-    obj.attach_mock(_ContextMock(), '__exit__')
-    obj.__enter__.return_value = obj
-    # if __exit__ return a value the exception is ignored,
-    # so it must return None here.
-    obj.__exit__.return_value = None
-    return obj
-
-
-def _bind(f, o):
-    @wraps(f)
-    def bound_meth(*fargs, **fkwargs):
-        return f(o, *fargs, **fkwargs)
-    return bound_meth
-
-
-if PY3:  # pragma: no cover
-    def _get_class_fun(meth):
-        return meth
-else:
-    def _get_class_fun(meth):
-        return meth.__func__
-
-
-class MockCallbacks(object):
-
-    def __new__(cls, *args, **kwargs):
-        r = Mock(name=cls.__name__)
-        _get_class_fun(cls.__init__)(r, *args, **kwargs)
-        for key, value in items(vars(cls)):
-            if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
-                if inspect.ismethod(value) or inspect.isfunction(value):
-                    r.__getattr__(key).side_effect = _bind(value, r)
-                else:
-                    r.__setattr__(key, value)
-        return r
-
-
-# -- adds assertWarns from recent unittest2, not in Python 2.7.
-
-class _AssertRaisesBaseContext(object):
-
-    def __init__(self, expected, test_case, callable_obj=None,
-                 expected_regex=None):
-        self.expected = expected
-        self.failureException = test_case.failureException
-        self.obj_name = None
-        if isinstance(expected_regex, string_types):
-            expected_regex = re.compile(expected_regex)
-        self.expected_regex = expected_regex
-
-
-def _is_magic_module(m):
-    # some libraries create custom module types that are lazily
-    # lodaded, e.g. Django installs some modules in sys.modules that
-    # will load _tkinter and other shit when touched.
-
-    # pyflakes refuses to accept 'noqa' for this isinstance.
-    cls, modtype = type(m), types.ModuleType
-    try:
-        variables = vars(cls)
-    except TypeError:
-        return True
-    else:
-        return (cls is not modtype and (
-            '__getattr__' in variables or
-            '__getattribute__' in variables))
-
-
-class _AssertWarnsContext(_AssertRaisesBaseContext):
-    """A context manager used to implement TestCase.assertWarns* methods."""
-
-    def __enter__(self):
-        # The __warningregistry__'s need to be in a pristine state for tests
-        # to work properly.
-        warnings.resetwarnings()
-        for v in list(values(sys.modules)):
-            # do not evaluate Django moved modules and other lazily
-            # initialized modules.
-            if v and not _is_magic_module(v):
-                # use raw __getattribute__ to protect even better from
-                # lazily loaded modules
-                try:
-                    object.__getattribute__(v, '__warningregistry__')
-                except AttributeError:
-                    pass
-                else:
-                    object.__setattr__(v, '__warningregistry__', {})
-        self.warnings_manager = warnings.catch_warnings(record=True)
-        self.warnings = self.warnings_manager.__enter__()
-        warnings.simplefilter('always', self.expected)
-        return self
-
-    def __exit__(self, exc_type, exc_value, tb):
-        self.warnings_manager.__exit__(exc_type, exc_value, tb)
-        if exc_type is not None:
-            # let unexpected exceptions pass through
-            return
-        try:
-            exc_name = self.expected.__name__
-        except AttributeError:
-            exc_name = str(self.expected)
-        first_matching = None
-        for m in self.warnings:
-            w = m.message
-            if not isinstance(w, self.expected):
-                continue
-            if first_matching is None:
-                first_matching = w
-            if (self.expected_regex is not None and
-                    not self.expected_regex.search(str(w))):
-                continue
-            # store warning for later retrieval
-            self.warning = w
-            self.filename = m.filename
-            self.lineno = m.lineno
-            return
-        # Now we simply try to choose a helpful failure message
-        if first_matching is not None:
-            raise self.failureException(
-                '%r does not match %r' % (
-                    self.expected_regex.pattern, str(first_matching)))
-        if self.obj_name:
-            raise self.failureException(
-                '%s not triggered by %s' % (exc_name, self.obj_name))
-        else:
-            raise self.failureException('%s not triggered' % exc_name)
-
-
-class Case(unittest.TestCase):
-    DeprecationWarning = DeprecationWarning
-    PendingDeprecationWarning = PendingDeprecationWarning
-
-    def patch(self, *path, **options):
-        manager = patch('.'.join(path), **options)
-        patched = manager.start()
-        self.addCleanup(manager.stop)
-        return patched
-
-    def mock_modules(self, *mods):
-        modules = []
-        for mod in mods:
-            mod = mod.split('.')
-            modules.extend(reversed([
-                '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
-            ]))
-        modules = sorted(set(modules))
-        return self.wrap_context(mock_module(*modules))
-
-    def on_nth_call_do(self, mock, side_effect, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.side_effect = side_effect
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def on_nth_call_return(self, mock, retval, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.return_value = retval
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def mask_modules(self, *modules):
-        self.wrap_context(mask_modules(*modules))
-
-    def wrap_context(self, context):
-        ret = context.__enter__()
-        self.addCleanup(partial(context.__exit__, None, None, None))
-        return ret
-
-    def mock_environ(self, env_name, env_value):
-        return self.wrap_context(mock_environ(env_name, env_value))
-
-    def assertWarns(self, expected_warning):
-        return _AssertWarnsContext(expected_warning, self, None)
-
-    def assertWarnsRegex(self, expected_warning, expected_regex):
-        return _AssertWarnsContext(expected_warning, self,
-                                   None, expected_regex)
-
-    @contextmanager
-    def assertDeprecated(self):
-        with self.assertWarnsRegex(self.DeprecationWarning,
-                                   r'scheduled for removal'):
-            yield
-
-    @contextmanager
-    def assertPendingDeprecation(self):
-        with self.assertWarnsRegex(self.PendingDeprecationWarning,
-                                   r'scheduled for deprecation'):
-            yield
-
-    def assertDictContainsSubset(self, expected, actual, msg=None):
-        missing, mismatched = [], []
-
-        for key, value in items(expected):
-            if key not in actual:
-                missing.append(key)
-            elif value != actual[key]:
-                mismatched.append('%s, expected: %s, actual: %s' % (
-                    safe_repr(key), safe_repr(value),
-                    safe_repr(actual[key])))
-
-        if not (missing or mismatched):
-            return
-
-        standard_msg = ''
-        if missing:
-            standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
-
-        if mismatched:
-            if standard_msg:
-                standard_msg += '; '
-            standard_msg += 'Mismatched values: %s' % (
-                ','.join(mismatched))
-
-        self.fail(self._formatMessage(msg, standard_msg))
-
-    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
-        missing = unexpected = None
-        try:
-            expected = sorted(expected_seq)
-            actual = sorted(actual_seq)
-        except TypeError:
-            # Unsortable items (example: set(), complex(), ...)
-            expected = list(expected_seq)
-            actual = list(actual_seq)
-            missing, unexpected = unorderable_list_difference(
-                expected, actual)
-        else:
-            return self.assertSequenceEqual(expected, actual, msg=msg)
-
-        errors = []
-        if missing:
-            errors.append(
-                'Expected, but missing:\n    %s' % (safe_repr(missing),)
-            )
-        if unexpected:
-            errors.append(
-                'Unexpected, but present:\n    %s' % (safe_repr(unexpected),)
-            )
-        if errors:
-            standardMsg = '\n'.join(errors)
-            self.fail(self._formatMessage(msg, standardMsg))
-
-
-class _CallableContext(object):
-
-    def __init__(self, context, cargs, ckwargs, fun):
-        self.context = context
-        self.cargs = cargs
-        self.ckwargs = ckwargs
-        self.fun = fun
-
-    def __call__(self, *args, **kwargs):
-        return self.fun(*args, **kwargs)
-
-    def __enter__(self):
-        self.ctx = self.context(*self.cargs, **self.ckwargs)
-        return self.ctx.__enter__()
-
-    def __exit__(self, *einfo):
-        if self.ctx:
-            return self.ctx.__exit__(*einfo)
-
-
-def decorator(predicate):
-
-    @wraps(predicate)
-    def take_arguments(*pargs, **pkwargs):
-
-        @wraps(predicate)
-        def decorator(cls):
-            if inspect.isclass(cls):
-                orig_setup = cls.setUp
-                orig_teardown = cls.tearDown
-
-                @wraps(cls.setUp)
-                def around_setup(*args, **kwargs):
-                    try:
-                        contexts = args[0].__rb3dc_contexts__
-                    except AttributeError:
-                        contexts = args[0].__rb3dc_contexts__ = []
-                    p = predicate(*pargs, **pkwargs)
-                    p.__enter__()
-                    contexts.append(p)
-                    return orig_setup(*args, **kwargs)
-                around_setup.__wrapped__ = cls.setUp
-                cls.setUp = around_setup
-
-                @wraps(cls.tearDown)
-                def around_teardown(*args, **kwargs):
-                    try:
-                        contexts = args[0].__rb3dc_contexts__
-                    except AttributeError:
-                        pass
-                    else:
-                        for context in contexts:
-                            context.__exit__(*sys.exc_info())
-                    orig_teardown(*args, **kwargs)
-                around_teardown.__wrapped__ = cls.tearDown
-                cls.tearDown = around_teardown
-
-                return cls
-            else:
-                @wraps(cls)
-                def around_case(*args, **kwargs):
-                    with predicate(*pargs, **pkwargs):
-                        return cls(*args, **kwargs)
-                return around_case
-
-        if len(pargs) == 1 and callable(pargs[0]):
-            fun, pargs = pargs[0], ()
-            return decorator(fun)
-        return _CallableContext(predicate, pargs, pkwargs, decorator)
-    return take_arguments
-
-
-@decorator
-@contextmanager
-def skip_unless_module(module, name=None):
-    try:
-        importlib.import_module(module)
-    except (ImportError, OSError):
-        raise SkipTest('module not installed: {0}'.format(name or module))
-    yield
-
-
-@decorator
-@contextmanager
-def skip_unless_symbol(symbol, name=None):
-    try:
-        symbol_by_name(symbol)
-    except (AttributeError, ImportError):
-        raise SkipTest('missing symbol {0}'.format(name or symbol))
-    yield
-
-
-def get_logger_handlers(logger):
-    return [
-        h for h in logger.handlers
-        if not isinstance(h, logging.NullHandler)
-    ]
-
-
-@decorator
-@contextmanager
-def wrap_logger(logger, loglevel=logging.ERROR):
-    old_handlers = get_logger_handlers(logger)
-    sio = WhateverIO()
-    siohandler = logging.StreamHandler(sio)
-    logger.handlers = [siohandler]
-
-    try:
-        yield sio
-    finally:
-        logger.handlers = old_handlers
-
-
-@decorator
-@contextmanager
-def mock_environ(env_name, env_value):
-    sentinel = object()
-    prev_val = os.environ.get(env_name, sentinel)
-    os.environ[env_name] = env_value
-    try:
-        yield env_value
-    finally:
-        if prev_val is sentinel:
-            os.environ.pop(env_name, None)
-        else:
-            os.environ[env_name] = prev_val
-
-
-@decorator
-@contextmanager
-def sleepdeprived(module=time):
-    old_sleep, module.sleep = module.sleep, noop
-    try:
-        yield
-    finally:
-        module.sleep = old_sleep
-
-
-@decorator
-@contextmanager
-def skip_if_python3(reason='incompatible'):
-    if PY3:
-        raise SkipTest('Python3: {0}'.format(reason))
-    yield
-
-
-@decorator
-@contextmanager
-def skip_if_environ(env_var_name):
-    if os.environ.get(env_var_name):
-        raise SkipTest('envvar {0} set'.format(env_var_name))
-    yield
-
-
-@decorator
-@contextmanager
-def _skip_test(reason, sign):
-    raise SkipTest('{0}: {1}'.format(sign, reason))
-    yield
-todo = partial(_skip_test, sign='TODO')
-skip = partial(_skip_test, sign='SKIP')
-
-
-# Taken from
-# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
-@decorator
-@contextmanager
-def mask_modules(*modnames):
-    """Ban some modules from being importable inside the context
-
-    For example:
-
-        >>> with mask_modules('sys'):
-        ...     try:
-        ...         import sys
-        ...     except ImportError:
-        ...         print('sys not found')
-        sys not found
-
-        >>> import sys  # noqa
-        >>> sys.version
-        (2, 5, 2, 'final', 0)
-
-    """
-    realimport = builtins.__import__
-
-    def myimp(name, *args, **kwargs):
-        if name in modnames:
-            raise ImportError('No module named %s' % name)
-        else:
-            return realimport(name, *args, **kwargs)
-
-    builtins.__import__ = myimp
-    try:
-        yield True
-    finally:
-        builtins.__import__ = realimport
-
-
-@decorator
-@contextmanager
-def override_stdouts():
-    """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
-    prev_out, prev_err = sys.stdout, sys.stderr
-    prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
-    mystdout, mystderr = WhateverIO(), WhateverIO()
-    sys.stdout = sys.__stdout__ = mystdout
-    sys.stderr = sys.__stderr__ = mystderr
-
-    try:
-        yield mystdout, mystderr
-    finally:
-        sys.stdout = prev_out
-        sys.stderr = prev_err
-        sys.__stdout__ = prev_rout
-        sys.__stderr__ = prev_rerr
-
-
-@decorator
-@contextmanager
-def replace_module_value(module, name, value=None):
-    has_prev = hasattr(module, name)
-    prev = getattr(module, name, None)
-    if value:
-        setattr(module, name, value)
-    else:
-        try:
-            delattr(module, name)
-        except AttributeError:
-            pass
-    try:
-        yield
-    finally:
-        if prev is not None:
-            setattr(module, name, prev)
-        if not has_prev:
-            try:
-                delattr(module, name)
-            except AttributeError:
-                pass
-pypy_version = partial(
-    replace_module_value, sys, 'pypy_version_info',
-)
-platform_pyimp = partial(
-    replace_module_value, platform, 'python_implementation',
-)
-
-
-@decorator
-@contextmanager
-def sys_platform(value):
-    prev, sys.platform = sys.platform, value
-    try:
-        yield
-    finally:
-        sys.platform = prev
-
-
-@decorator
-@contextmanager
-def reset_modules(*modules):
-    prev = {k: sys.modules.pop(k) for k in modules if k in sys.modules}
-    try:
-        yield
-    finally:
-        sys.modules.update(prev)
-
-
-@decorator
-@contextmanager
-def patch_modules(*modules):
-    prev = {}
-    for mod in modules:
-        prev[mod] = sys.modules.get(mod)
-        sys.modules[mod] = types.ModuleType(module_name_t(mod))
-    try:
-        yield
-    finally:
-        for name, mod in items(prev):
-            if mod is None:
-                sys.modules.pop(name, None)
-            else:
-                sys.modules[name] = mod
-
-
-@decorator
-@contextmanager
-def mock_module(*names):
-    prev = {}
-
-    class MockModule(types.ModuleType):
-
-        def __getattr__(self, attr):
-            setattr(self, attr, Mock())
-            return types.ModuleType.__getattribute__(self, attr)
-
-    mods = []
-    for name in names:
-        try:
-            prev[name] = sys.modules[name]
-        except KeyError:
-            pass
-        mod = sys.modules[name] = MockModule(module_name_t(name))
-        mods.append(mod)
-    try:
-        yield mods
-    finally:
-        for name in names:
-            try:
-                sys.modules[name] = prev[name]
-            except KeyError:
-                try:
-                    del(sys.modules[name])
-                except KeyError:
-                    pass
-
-
-@contextmanager
-def mock_context(mock, typ=Mock):
-    context = mock.return_value = Mock()
-    context.__enter__ = typ()
-    context.__exit__ = typ()
-
-    def on_exit(*x):
-        if x[0]:
-            reraise(x[0], x[1], x[2])
-    context.__exit__.side_effect = on_exit
-    context.__enter__.return_value = context
-    try:
-        yield context
-    finally:
-        context.reset()
-
-
-@decorator
-@contextmanager
-def mock_open(typ=WhateverIO, side_effect=None):
-    with patch(open_fqdn) as open_:
-        with mock_context(open_) as context:
-            if side_effect is not None:
-                context.__enter__.side_effect = side_effect
-            val = context.__enter__.return_value = typ()
-            val.__exit__ = Mock()
-            yield val
-
-
-@decorator
-@contextmanager
-def skip_if_platform(platform_name, name=None):
-    if sys.platform.startswith(platform_name):
-        raise SkipTest('does not work on {0}'.format(platform_name or name))
-    yield
-skip_if_jython = partial(skip_if_platform, 'java', name='Jython')
-skip_if_win32 = partial(skip_if_platform, 'win32', name='Windows')
-skip_if_darwin = partial(skip_if_platform, 'darwin', name='OS X')
-
-
-@decorator
-@contextmanager
-def skip_if_pypy():
-    if getattr(sys, 'pypy_version_info', None):
-        raise SkipTest('does not work on PyPy')
-    yield
-
-
-@decorator
-@contextmanager
-def restore_logging():
-    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
-    root = logging.getLogger()
-    level = root.level
-    handlers = root.handlers
-
-    try:
-        yield
-    finally:
-        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
-        root.level = level
-        root.handlers[:] = handlers

+ 13 - 17
celery/tests/app/test_app.py

@@ -29,12 +29,8 @@ from celery.tests.case import (
     Case,
     ContextMock,
     depends_on_current_app,
-    mask_modules,
+    mock,
     patch,
-    platform_pyimp,
-    sys_platform,
-    pypy_version,
-    mock_environ,
 )
 from celery.utils import uuid
 from celery.utils.mail import ErrorMail
@@ -236,7 +232,7 @@ class test_App(AppCase):
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
         )
 
-    @mock_environ('CELERY_BROKER_URL', '')
+    @mock.environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
         with self.Celery(broker='foo://baribaz') as app:
             self.assertEqual(app.conf.broker_url, 'foo://baribaz')
@@ -850,7 +846,7 @@ class test_App(AppCase):
         self.assertIn('add2', self.app.conf.beat_schedule)
 
     def test_pool_no_multiprocessing(self):
-        with mask_modules('multiprocessing.util'):
+        with mock.mask_modules('multiprocessing.util'):
             pool = self.app.pool
             self.assertIs(pool, self.app._pool)
 
@@ -953,26 +949,26 @@ class test_debugging_utils(AppCase):
 class test_pyimplementation(AppCase):
 
     def test_platform_python_implementation(self):
-        with platform_pyimp(lambda: 'Xython'):
+        with mock.platform_pyimp(lambda: 'Xython'):
             self.assertEqual(pyimplementation(), 'Xython')
 
     def test_platform_jython(self):
-        with platform_pyimp():
-            with sys_platform('java 1.6.51'):
+        with mock.platform_pyimp():
+            with mock.sys_platform('java 1.6.51'):
                 self.assertIn('Jython', pyimplementation())
 
     def test_platform_pypy(self):
-        with platform_pyimp():
-            with sys_platform('darwin'):
-                with pypy_version((1, 4, 3)):
+        with mock.platform_pyimp():
+            with mock.sys_platform('darwin'):
+                with mock.pypy_version((1, 4, 3)):
                     self.assertIn('PyPy', pyimplementation())
-                with pypy_version((1, 4, 3, 'a4')):
+                with mock.pypy_version((1, 4, 3, 'a4')):
                     self.assertIn('PyPy', pyimplementation())
 
     def test_platform_fallback(self):
-        with platform_pyimp():
-            with sys_platform('darwin'):
-                with pypy_version():
+        with mock.platform_pyimp():
+            with mock.sys_platform('darwin'):
+                with mock.pypy_version():
                     self.assertEqual('CPython', pyimplementation())
 
 

+ 2 - 2
celery/tests/app/test_beat.py

@@ -11,7 +11,7 @@ from celery.schedules import schedule
 from celery.utils import uuid
 from celery.utils.objects import Bunch
 
-from celery.tests.case import AppCase, Mock, call, patch, skip_unless_module
+from celery.tests.case import AppCase, Mock, call, patch, skip
 
 
 class MockShelve(dict):
@@ -485,7 +485,7 @@ class test_Service(AppCase):
 
 class test_EmbeddedService(AppCase):
 
-    @skip_unless_module('_multiprocessing', name='multiprocessing')
+    @skip.unless_module('_multiprocessing', name='multiprocessing')
     def test_start_stop_process(self):
         from billiard.process import Process
 

+ 9 - 9
celery/tests/app/test_defaults.py

@@ -10,7 +10,7 @@ from celery.app.defaults import (
 )
 from celery.five import values
 
-from celery.tests.case import AppCase, pypy_version, sys_platform
+from celery.tests.case import AppCase, mock
 
 
 class test_defaults(AppCase):
@@ -29,15 +29,15 @@ class test_defaults(AppCase):
         val = object()
         self.assertIs(self.defaults.Option.typemap['any'](val), val)
 
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 4, 0))
     def test_default_pool_pypy_14(self):
-        with sys_platform('darwin'):
-            with pypy_version((1, 4, 0)):
-                self.assertEqual(self.defaults.DEFAULT_POOL, 'solo')
+        self.assertEqual(self.defaults.DEFAULT_POOL, 'solo')
 
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 5, 0))
     def test_default_pool_pypy_15(self):
-        with sys_platform('darwin'):
-            with pypy_version((1, 5, 0)):
-                self.assertEqual(self.defaults.DEFAULT_POOL, 'prefork')
+        self.assertEqual(self.defaults.DEFAULT_POOL, 'prefork')
 
     def test_compat_indices(self):
         self.assertFalse(any(key.isupper() for key in DEFAULTS))
@@ -54,9 +54,9 @@ class test_defaults(AppCase):
         for key in _TO_OLD_KEY:
             self.assertIn(key, SETTING_KEYS)
 
+    @mock.sys_platform('java 1.6.51')
     def test_default_pool_jython(self):
-        with sys_platform('java 1.6.51'):
-            self.assertEqual(self.defaults.DEFAULT_POOL, 'threads')
+        self.assertEqual(self.defaults.DEFAULT_POOL, 'threads')
 
     def test_find(self):
         find = self.defaults.find

+ 2 - 2
celery/tests/app/test_loaders.py

@@ -13,7 +13,7 @@ from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 from celery.utils.mail import SendmailWarning
 
-from celery.tests.case import AppCase, Case, Mock, mock_environ, patch
+from celery.tests.case import AppCase, Case, Mock, mock, patch
 
 
 class DummyLoader(base.BaseLoader):
@@ -144,7 +144,7 @@ class test_DefaultLoader(AppCase):
             l.read_configuration(fail_silently=False)
 
     @patch('celery.loaders.base.find_module')
-    @mock_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
+    @mock.environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
     def test_read_configuration_py_in_name(self, find_module):
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)

+ 86 - 88
celery/tests/app/test_log.py

@@ -20,11 +20,9 @@ from celery.utils.log import (
     in_sighandler,
     logger_isa,
 )
-from celery.tests.case import (
-    AppCase, Mock, mask_modules, skip_if_python3,
-    override_stdouts, patch, wrap_logger, restore_logging,
-)
-from celery.tests._case import get_logger_handlers
+
+from case.utils import get_logger_handlers
+from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 
 class test_TaskFormatter(AppCase):
@@ -156,7 +154,7 @@ class test_ColorFormatter(AppCase):
         self.assertIn('<Unrepresentable', msg)
         self.assertEqual(safe_str.call_count, 1)
 
-    @skip_if_python3()
+    @skip.if_python3()
     @patch('celery.utils.log.safe_str')
     def test_format_raises_no_color(self, safe_str):
         x = ColorFormatter(use_color=False)
@@ -184,14 +182,14 @@ class test_default_logger(AppCase):
         logger = get_logger(base_logger.name)
         self.assertIs(logger.parent, logging.root)
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_misc(self):
-        with restore_logging():
-            self.app.log.setup_logging_subsystem(loglevel=None)
+        self.app.log.setup_logging_subsystem(loglevel=None)
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_misc2(self):
-        with restore_logging():
-            self.app.conf.worker_hijack_root_logger = True
-            self.app.log.setup_logging_subsystem()
+        self.app.conf.worker_hijack_root_logger = True
+        self.app.log.setup_logging_subsystem()
 
     def test_get_default_logger(self):
         self.assertTrue(self.app.log.get_default_logger())
@@ -202,19 +200,19 @@ class test_default_logger(AppCase):
         self.app.log._configure_logger(None, sys.stderr, None, '', False)
         logger.handlers[:] = []
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_colorize(self):
-        with restore_logging():
-            self.app.log.setup_logging_subsystem(colorize=None)
-            self.app.log.setup_logging_subsystem(colorize=True)
+        self.app.log.setup_logging_subsystem(colorize=None)
+        self.app.log.setup_logging_subsystem(colorize=True)
 
+    @mock.restore_logging()
     def test_setup_logging_subsystem_no_mputil(self):
-        with restore_logging():
-            with mask_modules('billiard.util'):
-                self.app.log.setup_logging_subsystem()
+        with mock.mask_modules('billiard.util'):
+            self.app.log.setup_logging_subsystem()
 
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
 
-        with wrap_logger(logger, loglevel=loglevel) as sio:
+        with mock.wrap_logger(logger, loglevel=loglevel) as sio:
             logger.log(loglevel, logmsg)
             return sio.getvalue().strip()
 
@@ -226,30 +224,30 @@ class test_default_logger(AppCase):
         val = self._assertLog(logger, logmsg, loglevel=loglevel)
         return self.assertFalse(val, reason)
 
+    @mock.restore_logging()
     def test_setup_logger(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False, colorize=True)
-            logger.handlers = []
-            self.app.log.already_setup = False
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False, colorize=None)
-            self.assertIs(
-                get_logger_handlers(logger)[0].stream, sys.__stderr__,
-                'setup_logger logs to stderr without logfile argument.',
-            )
-
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False, colorize=True)
+        logger.handlers = []
+        self.app.log.already_setup = False
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False, colorize=None)
+        self.assertIs(
+            get_logger_handlers(logger)[0].stream, sys.__stderr__,
+            'setup_logger logs to stderr without logfile argument.',
+        )
+
+    @mock.restore_logging()
     def test_setup_logger_no_handlers_stream(self):
-        with restore_logging():
-            l = self.get_logger()
-            l.handlers = []
+        l = self.get_logger()
+        l.handlers = []
 
-            with override_stdouts() as outs:
-                stdout, stderr = outs
-                l = self.setup_logger(logfile=sys.stderr,
-                                      loglevel=logging.INFO, root=False)
-                l.info('The quick brown fox...')
-                self.assertIn('The quick brown fox...', stderr.getvalue())
+        with mock.stdouts() as outs:
+            stdout, stderr = outs
+            l = self.setup_logger(logfile=sys.stderr,
+                                  loglevel=logging.INFO, root=False)
+            l.info('The quick brown fox...')
+            self.assertIn('The quick brown fox...', stderr.getvalue())
 
     @patch('os.fstat')
     def test_setup_logger_no_handlers_file(self, *args):
@@ -257,7 +255,7 @@ class test_default_logger(AppCase):
         _open = ('builtins.open' if sys.version_info[0] == 3
                  else '__builtin__.open')
         with patch(_open) as osopen:
-            with restore_logging():
+            with mock.restore_logging():
                 files = defaultdict(StringIO)
 
                 def open_file(filename, *args, **kwargs):
@@ -277,59 +275,59 @@ class test_default_logger(AppCase):
                 )
                 self.assertIn(tempfile, files)
 
+    @mock.restore_logging()
     def test_redirect_stdouts(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False)
-            try:
-                with wrap_logger(logger) as sio:
-                    self.app.log.redirect_stdouts_to_logger(
-                        logger, loglevel=logging.ERROR,
-                    )
-                    logger.error('foo')
-                    self.assertIn('foo', sio.getvalue())
-                    self.app.log.redirect_stdouts_to_logger(
-                        logger, stdout=False, stderr=False,
-                    )
-            finally:
-                sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False)
+        try:
+            with mock.wrap_logger(logger) as sio:
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, loglevel=logging.ERROR,
+                )
+                logger.error('foo')
+                self.assertIn('foo', sio.getvalue())
+                self.app.log.redirect_stdouts_to_logger(
+                    logger, stdout=False, stderr=False,
+                )
+        finally:
+            sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
+    @mock.restore_logging()
     def test_logging_proxy(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False)
-
-            with wrap_logger(logger) as sio:
-                p = LoggingProxy(logger, loglevel=logging.ERROR)
-                p.close()
-                p.write('foo')
-                self.assertNotIn('foo', sio.getvalue())
-                p.closed = False
-                p.write('foo')
-                self.assertIn('foo', sio.getvalue())
-                lines = ['baz', 'xuzzy']
-                p.writelines(lines)
-                for line in lines:
-                    self.assertIn(line, sio.getvalue())
-                p.flush()
-                p.close()
-                self.assertFalse(p.isatty())
-
-                with override_stdouts() as (stdout, stderr):
-                    with in_sighandler():
-                        p.write('foo')
-                        self.assertTrue(stderr.getvalue())
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False)
 
-    def test_logging_proxy_recurse_protection(self):
-        with restore_logging():
-            logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                       root=False)
+        with mock.wrap_logger(logger) as sio:
             p = LoggingProxy(logger, loglevel=logging.ERROR)
-            p._thread.recurse_protection = True
-            try:
-                self.assertIsNone(p.write('FOOFO'))
-            finally:
-                p._thread.recurse_protection = False
+            p.close()
+            p.write('foo')
+            self.assertNotIn('foo', sio.getvalue())
+            p.closed = False
+            p.write('foo')
+            self.assertIn('foo', sio.getvalue())
+            lines = ['baz', 'xuzzy']
+            p.writelines(lines)
+            for line in lines:
+                self.assertIn(line, sio.getvalue())
+            p.flush()
+            p.close()
+            self.assertFalse(p.isatty())
+
+            with mock.stdouts() as (stdout, stderr):
+                with in_sighandler():
+                    p.write('foo')
+                    self.assertTrue(stderr.getvalue())
+
+    @mock.restore_logging()
+    def test_logging_proxy_recurse_protection(self):
+        logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
+                                   root=False)
+        p = LoggingProxy(logger, loglevel=logging.ERROR)
+        p._thread.recurse_protection = True
+        try:
+            self.assertIsNone(p.write('FOOFO'))
+        finally:
+            p._thread.recurse_protection = False
 
 
 class test_task_logger(test_default_logger):

+ 3 - 3
celery/tests/app/test_schedules.py

@@ -10,7 +10,7 @@ from celery.five import items
 from celery.schedules import (
     ParseException, crontab, crontab_parser, schedule, solar,
 )
-from celery.tests.case import AppCase, Mock, skip_unless_module, todo
+from celery.tests.case import AppCase, Mock, skip
 
 
 @contextmanager
@@ -23,7 +23,7 @@ def patch_crontab_nowfun(cls, retval):
         cls.nowfun = prev_nowfun
 
 
-@skip_unless_module('ephem')
+@skip.unless_module('ephem')
 class test_solar(AppCase):
 
     def setup(self):
@@ -735,7 +735,7 @@ class test_crontab_is_due(AppCase):
             self.assertTrue(due)
             self.assertEqual(remaining, 60.)
 
-    @todo('unstable test')
+    @skip.todo('unstable test')
     def test_monthly_moy_execution_is_not_due(self):
         with patch_crontab_nowfun(
                 self.monthly_moy, datetime(2013, 6, 28, 14, 30)):

+ 2 - 4
celery/tests/backends/test_amqp.py

@@ -14,9 +14,7 @@ from celery.five import Empty, Queue, range
 from celery.result import AsyncResult
 from celery.utils import uuid
 
-from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, sleepdeprived,
-)
+from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
 
 
 class SomeClass(object):
@@ -112,7 +110,7 @@ class test_AMQPBackend(AppCase):
         b = self.create_backend(expires=timedelta(minutes=1))
         self.assertEqual(b.queue_arguments.get('x-expires'), 60 * 1000.0)
 
-    @sleepdeprived()
+    @mock.sleepdeprived()
     def test_store_result_retries(self):
         iterations = [0]
         stop_raising_at = [5]

+ 3 - 4
celery/tests/backends/test_base.py

@@ -25,9 +25,7 @@ from celery.result import result_from_tuple
 from celery.utils import uuid
 from celery.utils.functional import pass1
 
-from celery.tests.case import (
-    ANY, AppCase, Case, Mock, call, patch, skip_if_python3,
-)
+from celery.tests.case import ANY, AppCase, Case, Mock, call, patch, skip
 
 
 class wrapobject(object):
@@ -94,7 +92,8 @@ class test_BaseBackend_interface(AppCase):
 
 class test_exception_pickle(AppCase):
 
-    @skip_if_python3('does not support old style classes')
+    @skip.if_python3(reason='does not support old style classes')
+    @skip.if_pypy()
     def test_oldstyle(self):
         self.assertTrue(fnpe(Oldstyle()))
 

+ 15 - 17
celery/tests/backends/test_cache.py

@@ -15,9 +15,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.five import items, module_name_t, string, text_t
 from celery.utils import uuid
 
-from celery.tests.case import (
-    AppCase, Mock, override_stdouts, mask_modules, patch, reset_modules,
-)
+from celery.tests.case import AppCase, Mock, mock, patch
 
 PY3 = sys.version_info[0] == 3
 
@@ -136,8 +134,8 @@ class test_CacheBackend(AppCase):
         b = CacheBackend(backend=backend, app=self.app)
         self.assertEqual(b.as_uri(), backend)
 
-    @override_stdouts
-    def test_regression_worker_startup_info(self):
+    @mock.stdouts
+    def test_regression_worker_startup_info(self, stdout, stderr):
         self.app.conf.result_backend = (
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
         )
@@ -197,7 +195,7 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
 
     def test_pylibmc(self):
         with self.mock_pylibmc():
-            with reset_modules('celery.backends.cache'):
+            with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 cache._imp = [None]
                 self.assertEqual(cache.get_best_memcache()[0].__module__,
@@ -205,16 +203,16 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
 
     def test_memcache(self):
         with self.mock_memcache():
-            with reset_modules('celery.backends.cache'):
-                with mask_modules('pylibmc'):
+            with mock.reset_modules('celery.backends.cache'):
+                with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     cache._imp = [None]
                     self.assertEqual(cache.get_best_memcache()[0]().__module__,
                                      'memcache')
 
     def test_no_implementations(self):
-        with mask_modules('pylibmc', 'memcache'):
-            with reset_modules('celery.backends.cache'):
+        with mock.mask_modules('pylibmc', 'memcache'):
+            with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 cache._imp = [None]
                 with self.assertRaises(ImproperlyConfigured):
@@ -222,7 +220,7 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
 
     def test_cached(self):
         with self.mock_pylibmc():
-            with reset_modules('celery.backends.cache'):
+            with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 cache._imp = [None]
                 cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
@@ -240,8 +238,8 @@ class test_memcache_key(AppCase, MockCacheMixin):
 
     def test_memcache_unicode_key(self):
         with self.mock_memcache():
-            with reset_modules('celery.backends.cache'):
-                with mask_modules('pylibmc'):
+            with mock.reset_modules('celery.backends.cache'):
+                with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     cache._imp = [None]
                     task_id, result = string(uuid()), 42
@@ -251,8 +249,8 @@ class test_memcache_key(AppCase, MockCacheMixin):
 
     def test_memcache_bytes_key(self):
         with self.mock_memcache():
-            with reset_modules('celery.backends.cache'):
-                with mask_modules('pylibmc'):
+            with mock.reset_modules('celery.backends.cache'):
+                with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     cache._imp = [None]
                     task_id, result = str_to_bytes(uuid()), 42
@@ -261,7 +259,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     self.assertEqual(b.get_result(task_id), result)
 
     def test_pylibmc_unicode_key(self):
-        with reset_modules('celery.backends.cache'):
+        with mock.reset_modules('celery.backends.cache'):
             with self.mock_pylibmc():
                 from celery.backends import cache
                 cache._imp = [None]
@@ -271,7 +269,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 self.assertEqual(b.get_result(task_id), result)
 
     def test_pylibmc_bytes_key(self):
-        with reset_modules('celery.backends.cache'):
+        with mock.reset_modules('celery.backends.cache'):
             with self.mock_pylibmc():
                 from celery.backends import cache
                 cache._imp = [None]

+ 130 - 140
celery/tests/backends/test_cassandra.py

@@ -6,13 +6,12 @@ from datetime import datetime
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.objects import Bunch
-from celery.tests.case import (
-    AppCase, Mock, mock_module, depends_on_current_app
-)
+from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
 
 CASSANDRA_MODULES = ['cassandra', 'cassandra.auth', 'cassandra.cluster']
 
 
+@mock.module(*CASSANDRA_MODULES)
 class test_CassandraBackend(AppCase):
 
     def setup(self):
@@ -22,174 +21,165 @@ class test_CassandraBackend(AppCase):
             cassandra_table='task_results',
         )
 
-    def test_init_no_cassandra(self):
-        """should raise ImproperlyConfigured when no python-driver
-        installed."""
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            prev, mod.cassandra = mod.cassandra, None
-            try:
-                with self.assertRaises(ImproperlyConfigured):
-                    mod.CassandraBackend(app=self.app)
-            finally:
-                mod.cassandra = prev
-
-    def test_init_with_and_without_LOCAL_QUROM(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            mod.cassandra = Mock()
-
-            cons = mod.cassandra.ConsistencyLevel = Bunch(
-                LOCAL_QUORUM='foo',
-            )
-
-            self.app.conf.cassandra_read_consistency = 'LOCAL_FOO'
-            self.app.conf.cassandra_write_consistency = 'LOCAL_FOO'
-
-            mod.CassandraBackend(app=self.app)
-            cons.LOCAL_FOO = 'bar'
-            mod.CassandraBackend(app=self.app)
-
-            # no servers raises ImproperlyConfigured
+    def test_init_no_cassandra(self, *modules):
+        # should raise ImproperlyConfigured when no python-driver
+        # installed.
+        from celery.backends import cassandra as mod
+        prev, mod.cassandra = mod.cassandra, None
+        try:
             with self.assertRaises(ImproperlyConfigured):
-                self.app.conf.cassandra_servers = None
-                mod.CassandraBackend(
-                    app=self.app, keyspace='b', column_family='c',
-                )
+                mod.CassandraBackend(app=self.app)
+        finally:
+            mod.cassandra = prev
 
-    @depends_on_current_app
-    def test_reduce(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends.cassandra import CassandraBackend
-            self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
+    def test_init_with_and_without_LOCAL_QUROM(self, *modules):
+        from celery.backends import cassandra as mod
+        mod.cassandra = Mock()
 
-    def test_get_task_meta_for(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            mod.cassandra = Mock()
+        cons = mod.cassandra.ConsistencyLevel = Bunch(
+            LOCAL_QUORUM='foo',
+        )
 
-            x = mod.CassandraBackend(app=self.app)
-            x._connection = True
-            session = x._session = Mock()
-            execute = session.execute = Mock()
-            execute.return_value = [
-                [states.SUCCESS, '1', datetime.now(), b'', b'']
-            ]
-            x.decode = Mock()
-            meta = x._get_task_meta_for('task_id')
-            self.assertEqual(meta['status'], states.SUCCESS)
-
-            x._session.execute.return_value = []
-            meta = x._get_task_meta_for('task_id')
-            self.assertEqual(meta['status'], states.PENDING)
-
-    def test_store_result(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            mod.cassandra = Mock()
+        self.app.conf.cassandra_read_consistency = 'LOCAL_FOO'
+        self.app.conf.cassandra_write_consistency = 'LOCAL_FOO'
 
-            x = mod.CassandraBackend(app=self.app)
-            x._connection = True
-            session = x._session = Mock()
-            session.execute = Mock()
-            x._store_result('task_id', 'result', states.SUCCESS)
+        mod.CassandraBackend(app=self.app)
+        cons.LOCAL_FOO = 'bar'
+        mod.CassandraBackend(app=self.app)
 
-    def test_process_cleanup(self):
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-            x = mod.CassandraBackend(app=self.app)
-            x.process_cleanup()
+        # no servers raises ImproperlyConfigured
+        with self.assertRaises(ImproperlyConfigured):
+            self.app.conf.cassandra_servers = None
+            mod.CassandraBackend(
+                app=self.app, keyspace='b', column_family='c',
+            )
 
-            self.assertIsNone(x._connection)
-            self.assertIsNone(x._session)
+    @depends_on_current_app
+    def test_reduce(self, *modules):
+        from celery.backends.cassandra import CassandraBackend
+        self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
+
+    def test_get_task_meta_for(self, *modules):
+        from celery.backends import cassandra as mod
+        mod.cassandra = Mock()
+
+        x = mod.CassandraBackend(app=self.app)
+        x._connection = True
+        session = x._session = Mock()
+        execute = session.execute = Mock()
+        execute.return_value = [
+            [states.SUCCESS, '1', datetime.now(), b'', b'']
+        ]
+        x.decode = Mock()
+        meta = x._get_task_meta_for('task_id')
+        self.assertEqual(meta['status'], states.SUCCESS)
+
+        x._session.execute.return_value = []
+        meta = x._get_task_meta_for('task_id')
+        self.assertEqual(meta['status'], states.PENDING)
+
+    def test_store_result(self, *modules):
+        from celery.backends import cassandra as mod
+        mod.cassandra = Mock()
+
+        x = mod.CassandraBackend(app=self.app)
+        x._connection = True
+        session = x._session = Mock()
+        session.execute = Mock()
+        x._store_result('task_id', 'result', states.SUCCESS)
+
+    def test_process_cleanup(self, *modules):
+        from celery.backends import cassandra as mod
+        x = mod.CassandraBackend(app=self.app)
+        x.process_cleanup()
+
+        self.assertIsNone(x._connection)
+        self.assertIsNone(x._session)
 
     def test_timeouting_cluster(self):
-        """Tests behaviour when Cluster.connect raises
-        cassandra.OperationTimedOut."""
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
+        # Tests behaviour when Cluster.connect raises
+        # cassandra.OperationTimedOut.
+        from celery.backends import cassandra as mod
 
-            class OTOExc(Exception):
-                pass
+        class OTOExc(Exception):
+            pass
 
-            class VeryFaultyCluster(object):
-                def __init__(self, *args, **kwargs):
-                    pass
+        class VeryFaultyCluster(object):
+            def __init__(self, *args, **kwargs):
+                pass
 
-                def connect(self, *args, **kwargs):
-                    raise OTOExc()
+            def connect(self, *args, **kwargs):
+                raise OTOExc()
 
-                def shutdown(self):
-                    pass
+            def shutdown(self):
+                pass
 
-            mod.cassandra = Mock()
-            mod.cassandra.OperationTimedOut = OTOExc
-            mod.cassandra.cluster = Mock()
-            mod.cassandra.cluster.Cluster = VeryFaultyCluster
+        mod.cassandra = Mock()
+        mod.cassandra.OperationTimedOut = OTOExc
+        mod.cassandra.cluster = Mock()
+        mod.cassandra.cluster.Cluster = VeryFaultyCluster
 
-            x = mod.CassandraBackend(app=self.app)
+        x = mod.CassandraBackend(app=self.app)
 
-            with self.assertRaises(OTOExc):
-                x._store_result('task_id', 'result', states.SUCCESS)
-            self.assertIsNone(x._connection)
-            self.assertIsNone(x._session)
+        with self.assertRaises(OTOExc):
+            x._store_result('task_id', 'result', states.SUCCESS)
+        self.assertIsNone(x._connection)
+        self.assertIsNone(x._session)
 
-            x.process_cleanup()  # should not raise
+        x.process_cleanup()  # should not raise
 
     def test_please_free_memory(self):
-        """Ensure that Cluster object IS shut down."""
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
+        # Ensure that Cluster object IS shut down.
+        from celery.backends import cassandra as mod
 
-            class RAMHoggingCluster(object):
+        class RAMHoggingCluster(object):
 
-                objects_alive = 0
+            objects_alive = 0
 
-                def __init__(self, *args, **kwargs):
-                    pass
+            def __init__(self, *args, **kwargs):
+                pass
 
-                def connect(self, *args, **kwargs):
-                    RAMHoggingCluster.objects_alive += 1
-                    return Mock()
+            def connect(self, *args, **kwargs):
+                RAMHoggingCluster.objects_alive += 1
+                return Mock()
 
-                def shutdown(self):
-                    RAMHoggingCluster.objects_alive -= 1
+            def shutdown(self):
+                RAMHoggingCluster.objects_alive -= 1
 
-            mod.cassandra = Mock()
+        mod.cassandra = Mock()
 
-            mod.cassandra.cluster = Mock()
-            mod.cassandra.cluster.Cluster = RAMHoggingCluster
+        mod.cassandra.cluster = Mock()
+        mod.cassandra.cluster.Cluster = RAMHoggingCluster
 
-            for x in range(0, 10):
-                x = mod.CassandraBackend(app=self.app)
-                x._store_result('task_id', 'result', states.SUCCESS)
-                x.process_cleanup()
+        for x in range(0, 10):
+            x = mod.CassandraBackend(app=self.app)
+            x._store_result('task_id', 'result', states.SUCCESS)
+            x.process_cleanup()
 
-            self.assertEquals(RAMHoggingCluster.objects_alive, 0)
+        self.assertEquals(RAMHoggingCluster.objects_alive, 0)
 
     def test_auth_provider(self):
-        """Ensure valid auth_provider works properly, and invalid one raises
-        ImproperlyConfigured exception."""
+        # Ensure valid auth_provider works properly, and invalid one raises
+        # ImproperlyConfigured exception.
+        from celery.backends import cassandra as mod
+
         class DummyAuth(object):
             ValidAuthProvider = Mock()
 
-        with mock_module(*CASSANDRA_MODULES):
-            from celery.backends import cassandra as mod
-
-            mod.cassandra = Mock()
-            mod.cassandra.auth = DummyAuth
-
-            # Valid auth_provider
-            self.app.conf.cassandra_auth_provider = 'ValidAuthProvider'
-            self.app.conf.cassandra_auth_kwargs = {
-                'username': 'stuff'
-            }
+        mod.cassandra = Mock()
+        mod.cassandra.auth = DummyAuth
+
+        # Valid auth_provider
+        self.app.conf.cassandra_auth_provider = 'ValidAuthProvider'
+        self.app.conf.cassandra_auth_kwargs = {
+            'username': 'stuff'
+        }
+        mod.CassandraBackend(app=self.app)
+
+        # Invalid auth_provider
+        self.app.conf.cassandra_auth_provider = 'SpiderManAuth'
+        self.app.conf.cassandra_auth_kwargs = {
+            'username': 'Jack'
+        }
+        with self.assertRaises(ImproperlyConfigured):
             mod.CassandraBackend(app=self.app)
-
-            # Invalid auth_provider
-            self.app.conf.cassandra_auth_provider = 'SpiderManAuth'
-            self.app.conf.cassandra_auth_kwargs = {
-                'username': 'Jack'
-            }
-            with self.assertRaises(ImproperlyConfigured):
-                mod.CassandraBackend(app=self.app)

+ 2 - 4
celery/tests/backends/test_couchbase.py

@@ -8,9 +8,7 @@ from celery.backends import couchbase as module
 from celery.backends.couchbase import CouchBaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
-from celery.tests.case import (
-    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
-)
+from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 try:
     import couchbase
@@ -20,7 +18,7 @@ except ImportError:
 COUCHBASE_BUCKET = 'celery_bucket'
 
 
-@skip_unless_module('couchbase')
+@skip.unless_module('couchbase')
 class test_CouchBaseBackend(AppCase):
 
     def setup(self):

+ 2 - 4
celery/tests/backends/test_couchdb.py

@@ -4,9 +4,7 @@ from celery.backends import couchdb as module
 from celery.backends.couchdb import CouchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
-from celery.tests.case import (
-    AppCase, Mock, patch, sentinel, skip_unless_module,
-)
+from celery.tests.case import AppCase, Mock, patch, sentinel, skip
 
 try:
     import pycouchdb
@@ -16,7 +14,7 @@ except ImportError:
 COUCHDB_CONTAINER = 'celery_container'
 
 
-@skip_unless_module('pycouchdb')
+@skip.unless_module('pycouchdb')
 class test_CouchBackend(AppCase):
 
     def setup(self):

+ 6 - 12
celery/tests/backends/test_database.py

@@ -9,13 +9,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.utils import uuid
 
 from celery.tests.case import (
-    AppCase,
-    Mock,
-    depends_on_current_app,
-    patch,
-    skip_if_pypy,
-    skip_if_jython,
-    skip_unless_module,
+    AppCase, Mock, depends_on_current_app, patch, skip,
 )
 
 try:
@@ -38,7 +32,7 @@ class SomeClass(object):
         self.data = data
 
 
-@skip_unless_module('sqlalchemy')
+@skip.unless_module('sqlalchemy')
 class test_session_cleanup(AppCase):
 
     def test_context(self):
@@ -56,9 +50,9 @@ class test_session_cleanup(AppCase):
         session.close.assert_called_with()
 
 
-@skip_unless_module('sqlalchemy')
-@skip_if_pypy()
-@skip_if_jython()
+@skip.unless_module('sqlalchemy')
+@skip.if_pypy()
+@skip.if_jython()
 class test_DatabaseBackend(AppCase):
 
     def setup(self):
@@ -214,7 +208,7 @@ class test_DatabaseBackend(AppCase):
         self.assertIn('foo', repr(TaskSet('foo', None)))
 
 
-@skip_unless_module('sqlalchemy')
+@skip.unless_module('sqlalchemy')
 class test_SessionManager(AppCase):
 
     def test_after_fork(self):

+ 2 - 2
celery/tests/backends/test_elasticsearch.py

@@ -5,10 +5,10 @@ from celery.backends import elasticsearch as module
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.exceptions import ImproperlyConfigured
 
-from celery.tests.case import AppCase, Mock, sentinel, skip_unless_module
+from celery.tests.case import AppCase, Mock, sentinel, skip
 
 
-@skip_unless_module('elasticsearch')
+@skip.unless_module('elasticsearch')
 class test_ElasticsearchBackend(AppCase):
 
     def setup(self):

+ 2 - 2
celery/tests/backends/test_filesystem.py

@@ -10,10 +10,10 @@ from celery.backends.filesystem import FilesystemBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.utils import uuid
 
-from celery.tests.case import AppCase, skip_if_win32
+from celery.tests.case import AppCase, skip
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_FilesystemBackend(AppCase):
 
     def setup(self):

+ 6 - 7
celery/tests/backends/test_mongodb.py

@@ -12,9 +12,8 @@ from celery.backends import mongodb as module
 from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, ANY,
-    depends_on_current_app, override_stdouts, patch, sentinel,
-    skip_unless_module,
+    ANY, AppCase, MagicMock, Mock,
+    mock, depends_on_current_app, patch, sentinel, skip,
 )
 
 COLLECTION = 'taskmeta_celery'
@@ -28,7 +27,7 @@ MONGODB_COLLECTION = 'collection1'
 MONGODB_GROUP_COLLECTION = 'group_collection1'
 
 
-@skip_unless_module('pymongo')
+@skip.unless_module('pymongo')
 class test_MongoBackend(AppCase):
 
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
@@ -407,8 +406,8 @@ class test_MongoBackend(AppCase):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
 
-    @override_stdouts
-    def test_regression_worker_startup_info(self):
+    @mock.stdouts
+    def test_regression_worker_startup_info(self, stdout, stderr):
         self.app.conf.result_backend = (
             'mongodb://user:password@host0.com:43437,host1.com:43437'
             '/work4us?replicaSet=rs&ssl=true'
@@ -418,7 +417,7 @@ class test_MongoBackend(AppCase):
         self.assertTrue(worker.startup_info())
 
 
-@skip_unless_module('pymongo')
+@skip.unless_module('pymongo')
 class test_MongoBackend_no_mock(AppCase):
 
     def test_encode_decode(self):

+ 4 - 4
celery/tests/backends/test_redis.py

@@ -13,8 +13,8 @@ from celery.datastructures import AttributeDict
 from celery.exceptions import ChordError, ImproperlyConfigured
 
 from celery.tests.case import (
-    ANY, AppCase, ContextMock, Mock, MockCallbacks,
-    call, depends_on_current_app, patch, skip_unless_module,
+    ANY, AppCase, ContextMock, Mock, mock,
+    call, depends_on_current_app, patch, skip,
 )
 
 
@@ -58,7 +58,7 @@ class Pipeline(object):
         return [step(*a, **kw) for step, a, kw in self.steps]
 
 
-class Redis(MockCallbacks):
+class Redis(mock.MockCallbacks):
     Connection = Connection
     Pipeline = Pipeline
 
@@ -142,7 +142,7 @@ class test_RedisBackend(AppCase):
         self.b = self.Backend(app=self.app)
 
     @depends_on_current_app
-    @skip_unless_module('redis')
+    @skip.unless_module('redis')
     def test_reduce(self):
         from celery.backends.redis import RedisBackend
         x = RedisBackend(app=self.app)

+ 2 - 5
celery/tests/backends/test_riak.py

@@ -5,15 +5,12 @@ from __future__ import absolute_import, unicode_literals
 from celery.backends import riak as module
 from celery.backends.riak import RiakBackend
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import (
-    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
-)
-
+from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 RIAK_BUCKET = 'riak_bucket'
 
 
-@skip_unless_module('riak')
+@skip.unless_module('riak')
 class test_RiakBackend(AppCase):
 
     def setup(self):

+ 9 - 9
celery/tests/bin/test_base.py

@@ -12,7 +12,7 @@ from celery.five import module_name_t
 from celery.utils.objects import Bunch
 
 from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, override_stdouts, patch,
+    AppCase, Mock, depends_on_current_app, mock, patch,
 )
 
 
@@ -144,14 +144,14 @@ class test_Command(AppCase):
         self.assertDictContainsSubset({'foo': 'bar', 'prog_name': 'foo'},
                                       kwargs2)
 
-    def test_with_bogus_args(self):
-        with override_stdouts() as (_, stderr):
-            cmd = MockCommand(app=self.app)
-            cmd.supports_args = False
-            with self.assertRaises(SystemExit):
-                cmd.execute_from_commandline(argv=['--bogus'])
-            self.assertTrue(stderr.getvalue())
-            self.assertIn('Unrecognized', stderr.getvalue())
+    @mock.stdouts
+    def test_with_bogus_args(self, _, stderr):
+        cmd = MockCommand(app=self.app)
+        cmd.supports_args = False
+        with self.assertRaises(SystemExit):
+            cmd.execute_from_commandline(argv=['--bogus'])
+        self.assertTrue(stderr.getvalue())
+        self.assertIn('Unrecognized', stderr.getvalue())
 
     def test_with_custom_config_module(self):
         prev = os.environ.pop('CELERY_CONFIG_MODULE', None)

+ 26 - 24
celery/tests/bin/test_beat.py

@@ -10,8 +10,7 @@ from celery import platforms
 from celery.bin import beat as beat_bin
 from celery.apps import beat as beatapp
 
-from celery.tests.case import AppCase, Mock, patch, restore_logging
-from kombu.tests.case import redirect_stdouts
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 class MockedShelveModule(object):
@@ -113,32 +112,35 @@ class test_Beat(AppCase):
         self.assertTrue(MockService.in_sync)
         MockService.in_sync = False
 
+    @mock.restore_logging()
     def test_setup_logging(self):
-        with restore_logging():
-            try:
-                # py3k
-                delattr(sys.stdout, 'logger')
-            except AttributeError:
-                pass
-            b = beatapp.Beat(app=self.app, redirect_stdouts=False)
-            b.redirect_stdouts = False
-            b.app.log.already_setup = False
-            b.setup_logging()
-            with self.assertRaises(AttributeError):
-                sys.stdout.logger
-
-    @redirect_stdouts
+        try:
+            # py3k
+            delattr(sys.stdout, 'logger')
+        except AttributeError:
+            pass
+        b = beatapp.Beat(app=self.app, redirect_stdouts=False)
+        b.redirect_stdouts = False
+        b.app.log.already_setup = False
+        b.setup_logging()
+        with self.assertRaises(AttributeError):
+            sys.stdout.logger
+
+    import sys
+    orig_stdout = sys.__stdout__
+
     @patch('celery.apps.beat.logger')
+    @mock.restore_logging()
+    @mock.stdouts
     def test_logs_errors(self, logger, stdout, stderr):
-        with restore_logging():
-            b = MockBeat3(
-                app=self.app, redirect_stdouts=False, socket_timeout=None,
-            )
-            b.start_scheduler()
-            self.assertTrue(logger.critical.called)
-
-    @redirect_stdouts
+        b = MockBeat3(
+            app=self.app, redirect_stdouts=False, socket_timeout=None,
+        )
+        b.start_scheduler()
+        self.assertTrue(logger.critical.called)
+
     @patch('celery.platforms.create_pidlock')
+    @mock.stdouts
     def test_use_pidfile(self, create_pidlock, stdout, stderr):
         b = MockBeat2(app=self.app, pidfile='pidfilelockfilepid',
                       socket_timeout=None, redirect_stdouts=False)

+ 2 - 2
celery/tests/bin/test_celeryd_detach.py

@@ -7,7 +7,7 @@ from celery.bin.celeryd_detach import (
     main,
 )
 
-from celery.tests.case import AppCase, Mock, override_stdouts, patch
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 if not IS_WINDOWS:
@@ -76,7 +76,7 @@ class test_PartialOptionParser(AppCase):
         ])
         self.assertEqual(options.pidfile, '/var/pid/foo.pid')
 
-        with override_stdouts():
+        with mock.stdouts():
             with self.assertRaises(SystemExit):
                 p.parse_args(['--logfile'])
             p.get_option('--logfile').nargs = 2

+ 2 - 2
celery/tests/bin/test_events.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import, unicode_literals
 
 from celery.bin import events
 
-from celery.tests.case import AppCase, patch, _old_patch, skip_unless_module
+from celery.tests.case import AppCase, patch, _old_patch, skip
 
 
 class MockCommand(object):
@@ -29,7 +29,7 @@ class test_events(AppCase):
         self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
         self.assertIn('celery events:dump', proctitle.last[0])
 
-    @skip_unless_module('curses')
+    @skip.unless_module('curses', import_errors=(ImportError, OSError))
     def test_run_top(self):
         @_old_patch('celery.events.cursesmon', 'evtop',
                     lambda **kw: 'me top, you?')

+ 2 - 2
celery/tests/bin/test_multi.py

@@ -17,7 +17,7 @@ from celery.bin.multi import (
 )
 from celery.five import WhateverIO
 
-from celery.tests.case import AppCase, Mock, patch, skip_unless_symbol
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 class test_functions(AppCase):
@@ -265,7 +265,7 @@ class test_MultiTool(AppCase):
 
         )
 
-    @skip_unless_symbol('signal.SIGKILL')
+    @skip.unless_symbol('signal.SIGKILL')
     def test_kill(self):
         self.t.getpids = Mock()
         self.t.getpids.return_value = [

+ 49 - 59
celery/tests/bin/test_worker.py

@@ -18,17 +18,7 @@ from celery.exceptions import (
 from celery.platforms import EX_FAILURE, EX_OK
 from celery.worker import state
 
-from celery.tests.case import (
-    AppCase,
-    Mock,
-    override_stdouts,
-    patch,
-    skip_if_jython,
-    skip_if_pypy,
-    skip_if_win32,
-    skip_unless_module,
-    skip_unless_symbol,
-)
+from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 
 class WorkerAppCase(AppCase):
@@ -47,14 +37,14 @@ class Worker(cd.Worker):
 class test_Worker(WorkerAppCase):
     Worker = Worker
 
-    @override_stdouts
-    def test_queues_string(self):
+    @mock.stdouts
+    def test_queues_string(self, stdout, stderr):
         w = self.app.Worker()
         w.setup_queues('foo,bar,baz')
         self.assertIn('foo', self.app.amqp.queues)
 
-    @override_stdouts
-    def test_cpu_count(self):
+    @mock.stdouts
+    def test_cpu_count(self, stdout, stderr):
         with patch('celery.worker.cpu_count') as cpu_count:
             cpu_count.side_effect = NotImplementedError()
             w = self.app.Worker(concurrency=None)
@@ -62,8 +52,8 @@ class test_Worker(WorkerAppCase):
         w = self.app.Worker(concurrency=5)
         self.assertEqual(w.concurrency, 5)
 
-    @override_stdouts
-    def test_windows_B_option(self):
+    @mock.stdouts
+    def test_windows_B_option(self, stdout, stderr):
         self.app.IS_WINDOWS = True
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(beat=True)
@@ -94,8 +84,8 @@ class test_Worker(WorkerAppCase):
                 x.maybe_detach(['--detach'])
             self.assertTrue(detached.called)
 
-    @override_stdouts
-    def test_invalid_loglevel_gives_error(self):
+    @mock.stdouts
+    def test_invalid_loglevel_gives_error(self, stdout, stderr):
         x = worker(app=self.app)
         with self.assertRaises(SystemExit):
             x.run(loglevel='GRIM_REAPER')
@@ -118,13 +108,13 @@ class test_Worker(WorkerAppCase):
         worker.loglevel = logging.INFO
         self.assertTrue(worker.extra_info())
 
-    @override_stdouts
-    def test_loglevel_string(self):
+    @mock.stdouts
+    def test_loglevel_string(self, stdout, stderr):
         worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
 
-    @override_stdouts
-    def test_run_worker(self):
+    @mock.stdouts
+    def test_run_worker(self, stdout, stderr):
         handlers = {}
 
         class Signals(platforms.Signals):
@@ -151,8 +141,8 @@ class test_Worker(WorkerAppCase):
         finally:
             platforms.signals = p
 
-    @override_stdouts
-    def test_startup_info(self):
+    @mock.stdouts
+    def test_startup_info(self, stdout, stderr):
         worker = self.Worker(app=self.app)
         worker.on_start()
         self.assertTrue(worker.startup_info())
@@ -189,19 +179,19 @@ class test_Worker(WorkerAppCase):
         finally:
             cd.ARTLINES = prev
 
-    @override_stdouts
-    def test_run(self):
+    @mock.stdouts
+    def test_run(self, stdout, stderr):
         self.Worker(app=self.app).on_start()
         self.Worker(app=self.app, purge=True).on_start()
         worker = self.Worker(app=self.app)
         worker.on_start()
 
-    @override_stdouts
-    def test_purge_messages(self):
+    @mock.stdouts
+    def test_purge_messages(self, stdout, stderr):
         self.Worker(app=self.app).purge_messages()
 
-    @override_stdouts
-    def test_init_queues(self):
+    @mock.stdouts
+    def test_init_queues(self, stdout, stderr):
         app = self.app
         c = app.conf
         app.amqp.queues = app.amqp.Queues({
@@ -231,8 +221,8 @@ class test_Worker(WorkerAppCase):
             app.amqp.queues['image'],
         )
 
-    @override_stdouts
-    def test_autoscale_argument(self):
+    @mock.stdouts
+    def test_autoscale_argument(self, stdout, stderr):
         worker1 = self.Worker(app=self.app, autoscale='10,3')
         self.assertListEqual(worker1.autoscale, [10, 3])
         worker2 = self.Worker(app=self.app, autoscale='10')
@@ -247,17 +237,17 @@ class test_Worker(WorkerAppCase):
         self.assertListEqual(worker2.include, ['os', 'sys'])
         self.Worker(app=self.app, include=['os', 'sys'])
 
-    @override_stdouts
-    def test_unknown_loglevel(self):
+    @mock.stdouts
+    def test_unknown_loglevel(self, stdout, stderr):
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(loglevel='ALIEN')
         worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
-    @skip_if_win32
-    @override_stdouts
     @patch('os._exit')
-    def test_warns_if_running_as_privileged_user(self, _exit):
+    @skip.if_win32()
+    @mock.stdouts
+    def test_warns_if_running_as_privileged_user(self, _exit, stdout, stderr):
         with patch('os.getuid') as getuid:
             getuid.return_value = 0
             self.app.conf.accept_content = ['pickle']
@@ -281,14 +271,14 @@ class test_Worker(WorkerAppCase):
                 worker = self.Worker(app=self.app)
                 worker.on_start()
 
-    @override_stdouts
-    def test_redirect_stdouts(self):
+    @mock.stdouts
+    def test_redirect_stdouts(self, stdout, stderr):
         self.Worker(app=self.app, redirect_stdouts=False)
         with self.assertRaises(AttributeError):
             sys.stdout.logger
 
-    @override_stdouts
-    def test_on_start_custom_logging(self):
+    @mock.stdouts
+    def test_on_start_custom_logging(self, stdout, stderr):
         self.app.log.redirect_stdouts = Mock()
         worker = self.Worker(app=self.app, redirect_stoutds=True)
         worker._custom_logging = True
@@ -306,8 +296,8 @@ class test_Worker(WorkerAppCase):
         finally:
             self.app.log.setup = prev
 
-    @override_stdouts
-    def test_startup_info_pool_is_str(self):
+    @mock.stdouts
+    def test_startup_info_pool_is_str(self, stdout, stderr):
         worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker.pool_cls = 'foo'
         worker.startup_info()
@@ -329,8 +319,8 @@ class test_Worker(WorkerAppCase):
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
 
-    @override_stdouts
-    def test_platform_tweaks_osx(self):
+    @mock.stdouts
+    def test_platform_tweaks_osx(self, stdout, stderr):
 
         class OSXWorker(Worker):
             proxy_workaround_installed = False
@@ -357,8 +347,8 @@ class test_Worker(WorkerAppCase):
         finally:
             cd.install_HUP_not_supported_handler = prev
 
-    @override_stdouts
-    def test_general_platform_tweaks(self):
+    @mock.stdouts
+    def test_general_platform_tweaks(self, stdout, stderr):
 
         restart_worker_handler_installed = [False]
 
@@ -378,8 +368,8 @@ class test_Worker(WorkerAppCase):
         finally:
             cd.install_worker_restart_handler = prev
 
-    @override_stdouts
-    def test_on_consumer_ready(self):
+    @mock.stdouts
+    def test_on_consumer_ready(self, stdout, stderr):
         worker_ready_sent = [False]
 
         @signals.worker_ready.connect
@@ -390,13 +380,13 @@ class test_Worker(WorkerAppCase):
         self.assertTrue(worker_ready_sent[0])
 
 
-@override_stdouts
+@mock.stdouts
 class test_funs(WorkerAppCase):
 
     def test_active_thread_count(self):
         self.assertTrue(cd.active_thread_count())
 
-    @skip_unless_module('setproctitle')
+    @skip.unless_module('setproctitle')
     def test_set_process_status(self):
         worker = Worker(app=self.app, hostname='xyzza')
         prev1, sys.argv = sys.argv, ['Arg0']
@@ -435,7 +425,7 @@ class test_funs(WorkerAppCase):
             sys.argv = s
 
 
-@override_stdouts
+@mock.stdouts
 class test_signal_handlers(WorkerAppCase):
 
     class _Worker(object):
@@ -504,7 +494,7 @@ class test_signal_handlers(WorkerAppCase):
             with self.assertRaises(WorkerTerminate):
                 next_handlers['SIGINT']('SIGINT', object())
 
-    @skip_unless_module('multiprocessing')
+    @skip.unless_module('multiprocessing')
     def test_worker_int_handler_only_stop_MainProcess(self):
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
@@ -535,7 +525,7 @@ class test_signal_handlers(WorkerAppCase):
         handlers = self.psig(cd.install_HUP_not_supported_handler, worker)
         handlers['SIGHUP']('SIGHUP', object())
 
-    @skip_unless_module('multiprocessing')
+    @skip.unless_module('multiprocessing')
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
@@ -586,14 +576,14 @@ class test_signal_handlers(WorkerAppCase):
                 state.should_stop = None
 
     @patch('sys.__stderr__')
-    @skip_if_pypy()
-    @skip_if_jython()
+    @skip.if_pypy()
+    @skip.if_jython()
     def test_worker_cry_handler(self, stderr):
         handlers = self.psig(cd.install_cry_handler)
         self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
         self.assertTrue(stderr.write.called)
 
-    @skip_unless_module('multiprocessing')
+    @skip.unless_module('multiprocessing')
     def test_worker_term_handler_only_stop_MainProcess(self):
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
@@ -614,7 +604,7 @@ class test_signal_handlers(WorkerAppCase):
             process.name = name
             state.should_stop = None
 
-    @skip_unless_symbol('os.execv')
+    @skip.unless_symbol('os.execv')
     @patch('celery.platforms.close_open_fds')
     @patch('atexit.register')
     @patch('os.close')

+ 9 - 5
celery/tests/case.py

@@ -8,7 +8,6 @@ import os
 import sys
 import threading
 
-from contextlib import contextmanager
 from copy import deepcopy
 from datetime import datetime, timedelta
 from functools import partial, wraps
@@ -22,10 +21,16 @@ from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
 from celery.utils.imports import qualname
 
-from ._case import *  # noqa
-from ._case import __all__ as _case_all, Case as _Case, decorator
+from case import (
+    ANY, ContextMock, MagicMock, Mock, call, mock, skip, patch, sentinel,
+)
+from case import Case as _Case
+from case.utils import decorator
+
+__all__ = [
+    'ANY', 'ContextMock', 'MagicMock', 'Mock',
+    'call', 'mock', 'skip', 'patch', 'sentinel',
 
-__all__ = _case_all + [
     'AppCase', 'TaskMessage', 'TaskMessage1',
     'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
 ]
@@ -246,7 +251,6 @@ class AppCase(Case):
 
 
 @decorator
-@contextmanager
 def assert_signal_called(signal, **expected):
     handler = Mock()
     call_handler = partial(handler)

+ 2 - 2
celery/tests/concurrency/test_eventlet.py

@@ -9,10 +9,10 @@ from celery.concurrency.eventlet import (
     TaskPool,
 )
 
-from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
-@skip_if_pypy()
+@skip.if_pypy()
 class EventletCase(AppCase):
 
     def setup(self):

+ 2 - 2
celery/tests/concurrency/test_gevent.py

@@ -6,7 +6,7 @@ from celery.concurrency.gevent import (
     apply_timeout,
 )
 
-from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
+from celery.tests.case import AppCase, Mock, patch, skip
 
 gevent_modules = (
     'gevent',
@@ -17,7 +17,7 @@ gevent_modules = (
 )
 
 
-@skip_if_pypy()
+@skip.if_pypy()
 class GeventCase(AppCase):
 
     def setup(self):

+ 2 - 2
celery/tests/concurrency/test_pool.py

@@ -5,7 +5,7 @@ import itertools
 
 from billiard.einfo import ExceptionInfo
 
-from celery.tests.case import AppCase, skip_unless_module
+from celery.tests.case import AppCase, skip
 
 
 def do_something(i):
@@ -23,7 +23,7 @@ def raise_something(i):
         return ExceptionInfo()
 
 
-@skip_unless_module('multiprocessing')
+@skip.unless_module('multiprocessing')
 class test_TaskPool(AppCase):
 
     def setup(self):

+ 5 - 7
celery/tests/concurrency/test_prefork.py

@@ -12,9 +12,7 @@ from celery.five import range
 from celery.utils.functional import noop
 from celery.utils.objects import Bunch
 
-from celery.tests.case import (
-    AppCase, Mock, patch, restore_logging, skip_if_win32, skip_unless_module,
-)
+from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 try:
     from celery.concurrency import prefork as mp
@@ -60,7 +58,7 @@ class test_process_initializer(AppCase):
     @patch('celery.platforms.signals')
     @patch('celery.platforms.set_mp_process_title')
     def test_process_initializer(self, set_mp_process_title, _signals):
-        with restore_logging():
+        with mock.restore_logging():
             from celery import signals
             from celery._state import _tls
             from celery.concurrency.prefork import (
@@ -186,12 +184,12 @@ class ExeMockTaskPool(mp.TaskPool):
     Pool = BlockingPool = ExeMockPool
 
 
-@skip_unless_module('multiprocessing')
+@skip.unless_module('multiprocessing')
 class PoolCase(AppCase):
     pass
 
 
-@skip_if_win32
+@skip.if_win32()
 class test_AsynPool(PoolCase):
 
     def test_gen_not_started(self):
@@ -297,7 +295,7 @@ class test_AsynPool(PoolCase):
         w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
 
 
-@skip_if_win32
+@skip.if_win32()
 class test_ResultHandler(PoolCase):
 
     def test_process_result(self):

+ 6 - 6
celery/tests/concurrency/test_threads.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import, unicode_literals
 
 from celery.concurrency.threads import NullDict, TaskPool, apply_target
 
-from celery.tests.case import AppCase, Case, Mock, mask_modules, mock_module
+from celery.tests.case import AppCase, Case, Mock, mock
 
 
 class test_NullDict(Case):
@@ -18,32 +18,32 @@ class test_TaskPool(AppCase):
 
     def test_without_threadpool(self):
 
-        with mask_modules('threadpool'):
+        with mock.mask_modules('threadpool'):
             with self.assertRaises(ImportError):
                 TaskPool(app=self.app)
 
     def test_with_threadpool(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             self.assertTrue(x.ThreadPool)
             self.assertTrue(x.WorkRequest)
 
     def test_on_start(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             x.on_start()
             self.assertTrue(x._pool)
             self.assertIsInstance(x._pool.workRequests, NullDict)
 
     def test_on_stop(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             x.on_start()
             x.on_stop()
             x._pool.dismissWorkers.assert_called_with(x.limit, do_join=True)
 
     def test_on_apply(self):
-        with mock_module('threadpool'):
+        with mock.module('threadpool'):
             x = TaskPool(app=self.app)
             x.on_start()
             callback = Mock()

+ 5 - 5
celery/tests/contrib/test_migrate.py

@@ -26,7 +26,7 @@ from celery.contrib.migrate import (
     move,
 )
 from celery.utils.encoding import bytes_t, ensure_bytes
-from celery.tests.case import AppCase, Mock, override_stdouts, patch
+from celery.tests.case import AppCase, Mock, mock, patch
 
 # hack to ignore error at shutdown
 QoS.restore_at_shutdown = False
@@ -213,10 +213,10 @@ class test_utils(AppCase):
         self.assertEqual(_maybe_queue(app, 'foo'), 313)
         self.assertEqual(_maybe_queue(app, Queue('foo')), Queue('foo'))
 
-    def test_filter_status(self):
-        with override_stdouts() as (stdout, stderr):
-            filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
-            self.assertTrue(stdout.getvalue())
+    @mock.stdouts
+    def test_filter_status(self, stdout, stderr):
+        filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
+        self.assertTrue(stdout.getvalue())
 
     def test_move_by_taskmap(self):
         with patch('celery.contrib.migrate.move') as move:

+ 3 - 3
celery/tests/contrib/test_rdb.py

@@ -9,7 +9,7 @@ from celery.contrib.rdb import (
     set_trace,
 )
 from celery.five import WhateverIO
-from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 class SockErr(socket.error):
@@ -32,7 +32,7 @@ class test_Rdb(AppCase):
         self.assertTrue(debugger.return_value.set_trace.called)
 
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
-    @skip_if_pypy()
+    @skip.if_pypy()
     def test_rdb(self, get_avail_port):
         sock = Mock()
         get_avail_port.return_value = (sock, 8000)
@@ -76,7 +76,7 @@ class test_Rdb(AppCase):
             rdb.set_quit.assert_called_with()
 
     @patch('socket.socket')
-    @skip_if_pypy()
+    @skip.if_pypy()
     def test_get_avail_port(self, sock):
         out = WhateverIO()
         sock.return_value.accept.return_value = (Mock(), ['helu'])

+ 2 - 2
celery/tests/events/test_cursesmon.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 
-from celery.tests.case import AppCase, skip_unless_module
+from celery.tests.case import AppCase, skip
 
 
 class MockWindow(object):
@@ -9,7 +9,7 @@ class MockWindow(object):
         return self.y, self.x
 
 
-@skip_unless_module('curses')
+@skip.unless_module('curses', import_errors=(ImportError, OSError))
 class test_CursesDisplay(AppCase):
 
     def setup(self):

+ 11 - 10
celery/tests/events/test_snapshot.py

@@ -2,7 +2,8 @@ from __future__ import absolute_import, unicode_literals
 
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
-from celery.tests.case import AppCase, patch, restore_logging
+
+from celery.tests.case import AppCase, mock, patch
 
 
 class TRef(object):
@@ -113,16 +114,16 @@ class test_evcam(AppCase):
         self.app.events = self.MockEvents()
         self.app.events.app = self.app
 
+    @mock.restore_logging()
     def test_evcam(self):
-        with restore_logging():
-            evcam(Polaroid, timer=timer, app=self.app)
-            evcam(Polaroid, timer=timer, loglevel='CRITICAL', app=self.app)
-            self.MockReceiver.raise_keyboard_interrupt = True
-            try:
-                with self.assertRaises(SystemExit):
-                    evcam(Polaroid, timer=timer, app=self.app)
-            finally:
-                self.MockReceiver.raise_keyboard_interrupt = False
+        evcam(Polaroid, timer=timer, app=self.app)
+        evcam(Polaroid, timer=timer, loglevel='CRITICAL', app=self.app)
+        self.MockReceiver.raise_keyboard_interrupt = True
+        try:
+            with self.assertRaises(SystemExit):
+                evcam(Polaroid, timer=timer, app=self.app)
+        finally:
+            self.MockReceiver.raise_keyboard_interrupt = False
 
     @patch('celery.platforms.create_pidlock')
     def test_evcam_pidfile(self, create_pidlock):

+ 2 - 2
celery/tests/events/test_state.py

@@ -19,7 +19,7 @@ from celery.events.state import (
 )
 from celery.five import range
 from celery.utils import uuid
-from celery.tests.case import AppCase, Mock, patch, todo
+from celery.tests.case import AppCase, Mock, patch, skip
 
 try:
     Decimal(2.6)
@@ -374,7 +374,7 @@ class test_State(AppCase):
         self.assertEqual(now[1][0], tC)
         self.assertEqual(now[2][0], tB)
 
-    @todo(reason='not working')
+    @skip.todo(reason='not working')
     def test_task_descending_clock_ordering(self):
         state = State()
         r = ev_logical_clock_ordering(state)

+ 11 - 13
celery/tests/fixups/test_django.py

@@ -12,9 +12,7 @@ from celery.fixups.django import (
 )
 from celery.utils.objects import Bunch
 
-from celery.tests.case import (
-    AppCase, Mock, patch, patch_modules, mask_modules,
-)
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 class FixupCase(AppCase):
@@ -78,11 +76,11 @@ class test_DjangoFixup(FixupCase):
                 fixup(self.app)
                 self.assertFalse(Fixup.called)
             with patch.dict(os.environ, DJANGO_SETTINGS_MODULE='settings'):
-                with mask_modules('django'):
+                with mock.mask_modules('django'):
                     with self.assertWarnsRegex(UserWarning, 'but Django is'):
                         fixup(self.app)
                         self.assertFalse(Fixup.called)
-                with patch_modules('django'):
+                with mock.patch_modules('django'):
                     fixup(self.app)
                     self.assertTrue(Fixup.called)
 
@@ -334,7 +332,7 @@ class test_DjangoWorkerFixup(FixupCase):
         django.setup.assert_called_with()
 
     def test_mysql_errors(self):
-        with patch_modules('MySQLdb'):
+        with mock.patch_modules('MySQLdb'):
             import MySQLdb as mod
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
@@ -343,12 +341,12 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('MySQLdb'):
+        with mock.mask_modules('MySQLdb'):
             with self.fixup_context(self.app):
                 pass
 
     def test_pg_errors(self):
-        with patch_modules('psycopg2'):
+        with mock.patch_modules('psycopg2'):
             import psycopg2 as mod
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
@@ -357,12 +355,12 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('psycopg2'):
+        with mock.mask_modules('psycopg2'):
             with self.fixup_context(self.app):
                 pass
 
     def test_sqlite_errors(self):
-        with patch_modules('sqlite3'):
+        with mock.patch_modules('sqlite3'):
             import sqlite3 as mod
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
@@ -371,12 +369,12 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('sqlite3'):
+        with mock.mask_modules('sqlite3'):
             with self.fixup_context(self.app):
                 pass
 
     def test_oracle_errors(self):
-        with patch_modules('cx_Oracle'):
+        with mock.patch_modules('cx_Oracle'):
             import cx_Oracle as mod
             mod.DatabaseError = Mock()
             mod.InterfaceError = Mock()
@@ -385,6 +383,6 @@ class test_DjangoWorkerFixup(FixupCase):
                 self.assertIn(mod.DatabaseError, f.database_errors)
                 self.assertIn(mod.InterfaceError, f.database_errors)
                 self.assertIn(mod.OperationalError, f.database_errors)
-        with mask_modules('cx_Oracle'):
+        with mock.mask_modules('cx_Oracle'):
             with self.fixup_context(self.app):
                 pass

+ 2 - 2
celery/tests/security/case.py

@@ -1,8 +1,8 @@
 from __future__ import absolute_import, unicode_literals
 
-from celery.tests.case import AppCase, skip_unless_module
+from celery.tests.case import AppCase, skip
 
 
-@skip_unless_module('OpenSSL.crypto', name='pyOpenSSL')
+@skip.unless_module('OpenSSL.crypto', name='pyOpenSSL')
 class SecurityCase(AppCase):
     pass

+ 3 - 3
celery/tests/security/test_certificate.py

@@ -6,7 +6,7 @@ from celery.security.certificate import Certificate, CertStore, FSCertStore
 from . import CERT1, CERT2, KEY1
 from .case import SecurityCase
 
-from celery.tests.case import Mock, mock_open, patch, todo
+from celery.tests.case import Mock, mock, patch, skip
 
 
 class test_Certificate(SecurityCase):
@@ -27,7 +27,7 @@ class test_Certificate(SecurityCase):
         with self.assertRaises(SecurityError):
             Certificate(KEY1)
 
-    @todo(reason='cert expired')
+    @skip.todo(reason='cert expired')
     def test_has_expired(self):
         self.assertFalse(Certificate(CERT1).has_expired())
 
@@ -68,7 +68,7 @@ class test_FSCertStore(SecurityCase):
         cert.has_expired.return_value = False
         isdir.return_value = True
         glob.return_value = ['foo.cert']
-        with mock_open():
+        with mock.open():
             cert.get_id.return_value = 1
             x = FSCertStore('/var/certs')
             self.assertIn(1, x._certs)

+ 2 - 2
celery/tests/security/test_security.py

@@ -26,7 +26,7 @@ from kombu.serialization import registry
 
 from .case import SecurityCase
 
-from celery.tests.case import Mock, mock_open, patch
+from celery.tests.case import Mock, mock, patch
 
 
 class test_security(SecurityCase):
@@ -86,7 +86,7 @@ class test_security(SecurityCase):
                 calls[0] += 1
 
         self.app.conf.task_serializer = 'auth'
-        with mock_open(side_effect=effect):
+        with mock.open(side_effect=effect):
             with patch('celery.security.registry') as registry:
                 store = Mock()
                 self.app.setup_security(['json'], key, cert, store)

+ 0 - 0
celery/tests/slow/__init__.py


+ 2 - 2
celery/tests/utils/test_datastructures.py

@@ -18,7 +18,7 @@ from celery.datastructures import (
 from celery.five import WhateverIO, items
 from celery.utils.objects import Bunch
 
-from celery.tests.case import Case, Mock, skip_if_win32
+from celery.tests.case import Case, Mock, skip
 
 
 class test_DictAttribute(Case):
@@ -167,7 +167,7 @@ class test_ExceptionInfo(Case):
             self.assertTrue(r)
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_LimitedSet(Case):
 
     def test_add(self):

+ 21 - 25
celery/tests/utils/test_platforms.py

@@ -38,11 +38,7 @@ try:
 except ImportError:  # pragma: no cover
     resource = None  # noqa
 
-from celery.tests._case import open_fqdn
-from celery.tests.case import (
-    Case, Mock,
-    call, override_stdouts, mock_open, patch, skip_if_win32,
-)
+from celery.tests.case import Case, Mock, call, mock, patch, skip
 
 
 class test_find_option_with_arg(Case):
@@ -60,7 +56,7 @@ class test_find_option_with_arg(Case):
         )
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_fd_by_path(Case):
 
     def test_finds(self):
@@ -141,7 +137,7 @@ class test_Signals(Case):
         self.assertTrue(signals.supported('INT'))
         self.assertFalse(signals.supported('SIGIMAGINARY'))
 
-    @skip_if_win32()
+    @skip.if_win32()
     def test_reset_alarm(self):
         with patch('signal.alarm') as _alarm:
             signals.reset_alarm()
@@ -186,7 +182,7 @@ class test_Signals(Case):
         signals['INT'] = lambda *a: a
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_get_fdmax(Case):
 
     @patch('resource.getrlimit')
@@ -205,7 +201,7 @@ class test_get_fdmax(Case):
             self.assertEqual(get_fdmax(None), 13)
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_maybe_drop_privileges(Case):
 
     def test_on_windows(self):
@@ -322,7 +318,7 @@ class test_maybe_drop_privileges(Case):
         self.assertFalse(setuid.called)
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_setget_uid_gid(Case):
 
     @patch('celery.platforms.parse_uid')
@@ -380,7 +376,7 @@ class test_setget_uid_gid(Case):
             parse_gid('group')
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_initgroups(Case):
 
     @patch('pwd.getpwuid')
@@ -416,7 +412,7 @@ class test_initgroups(Case):
                 os.initgroups = prev
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_detached(Case):
 
     def test_without_resource(self):
@@ -431,7 +427,7 @@ class test_detached(Case):
     @patch('celery.platforms.signals')
     @patch('celery.platforms.maybe_drop_privileges')
     @patch('os.geteuid')
-    @patch(open_fqdn)
+    @patch(mock.open_fqdn)
     def test_default(self, open, geteuid, maybe_drop,
                      signals, pidlock):
         geteuid.return_value = 0
@@ -456,7 +452,7 @@ class test_detached(Case):
         pidlock.assert_called_with('/foo/bar/pid')
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_DaemonContext(Case):
 
     @patch('os.fork')
@@ -522,7 +518,7 @@ class test_DaemonContext(Case):
             x.open()
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_Pidfile(Case):
 
     @patch('celery.platforms.Pidfile')
@@ -530,7 +526,7 @@ class test_Pidfile(Case):
         p = Pidfile.return_value = Mock()
         p.is_locked.return_value = True
         p.remove_if_stale.return_value = False
-        with override_stdouts() as (_, err):
+        with mock.stdouts() as (_, err):
             with self.assertRaises(SystemExit):
                 create_pidlock('/var/pid')
             self.assertIn('already exists', err.getvalue())
@@ -567,14 +563,14 @@ class test_Pidfile(Case):
         self.assertFalse(p.is_locked())
 
     def test_read_pid(self):
-        with mock_open() as s:
+        with mock.open() as s:
             s.write('1816\n')
             s.seek(0)
             p = Pidfile('/var/pid')
             self.assertEqual(p.read_pid(), 1816)
 
     def test_read_pid_partially_written(self):
-        with mock_open() as s:
+        with mock.open() as s:
             s.write('1816')
             s.seek(0)
             p = Pidfile('/var/pid')
@@ -584,20 +580,20 @@ class test_Pidfile(Case):
     def test_read_pid_raises_ENOENT(self):
         exc = IOError()
         exc.errno = errno.ENOENT
-        with mock_open(side_effect=exc):
+        with mock.open(side_effect=exc):
             p = Pidfile('/var/pid')
             self.assertIsNone(p.read_pid())
 
     def test_read_pid_raises_IOError(self):
         exc = IOError()
         exc.errno = errno.EAGAIN
-        with mock_open(side_effect=exc):
+        with mock.open(side_effect=exc):
             p = Pidfile('/var/pid')
             with self.assertRaises(IOError):
                 p.read_pid()
 
     def test_read_pid_bogus_pidfile(self):
-        with mock_open() as s:
+        with mock.open() as s:
             s.write('eighteensixteen\n')
             s.seek(0)
             p = Pidfile('/var/pid')
@@ -655,7 +651,7 @@ class test_Pidfile(Case):
 
     @patch('os.kill')
     def test_remove_if_stale_process_dead(self, kill):
-        with override_stdouts():
+        with mock.stdouts():
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
             p.read_pid.return_value = 1816
@@ -668,7 +664,7 @@ class test_Pidfile(Case):
             p.remove.assert_called_with()
 
     def test_remove_if_stale_broken_pid(self):
-        with override_stdouts():
+        with mock.stdouts():
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
             p.read_pid.side_effect = ValueError()
@@ -690,7 +686,7 @@ class test_Pidfile(Case):
     @patch('os.getpid')
     @patch('os.open')
     @patch('os.fdopen')
-    @patch(open_fqdn)
+    @patch(mock.open_fqdn)
     def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
         getpid.return_value = 1816
         osopen.return_value = 13
@@ -717,7 +713,7 @@ class test_Pidfile(Case):
     @patch('os.getpid')
     @patch('os.open')
     @patch('os.fdopen')
-    @patch(open_fqdn)
+    @patch(mock.open_fqdn)
     def test_write_reread_fails(self, open_, fdopen,
                                 osopen, getpid, fsync):
         getpid.return_value = 1816

+ 2 - 2
celery/tests/utils/test_serialization.py

@@ -7,7 +7,7 @@ from celery.utils.serialization import (
     get_pickleable_etype,
 )
 
-from celery.tests.case import Case, mask_modules
+from celery.tests.case import Case, mock
 
 
 class test_AAPickle(Case):
@@ -15,7 +15,7 @@ class test_AAPickle(Case):
     def test_no_cpickle(self):
         prev = sys.modules.pop('celery.utils.serialization', None)
         try:
-            with mask_modules('cPickle'):
+            with mock.mask_modules('cPickle'):
                 from celery.utils.serialization import pickle
                 import pickle as orig_pickle
                 self.assertIs(pickle.dumps, orig_pickle.dumps)

+ 3 - 3
celery/tests/utils/test_sysinfo.py

@@ -2,10 +2,10 @@ from __future__ import absolute_import, unicode_literals
 
 from celery.utils.sysinfo import load_average, df
 
-from celery.tests.case import Case, patch, skip_unless_symbol
+from celery.tests.case import Case, patch, skip
 
 
-@skip_unless_symbol('os.getloadavg')
+@skip.unless_symbol('os.getloadavg')
 class test_load_average(Case):
 
     def test_avg(self):
@@ -16,7 +16,7 @@ class test_load_average(Case):
             self.assertEqual(l, (0.55, 0.64, 0.7))
 
 
-@skip_unless_symbol('posix.statvfs_result')
+@skip.unless_symbol('posix.statvfs_result')
 class test_df(Case):
 
     def test_df(self):

+ 2 - 2
celery/tests/utils/test_term.py

@@ -7,10 +7,10 @@ from celery.utils import term
 from celery.utils.term import colored, fg
 from celery.five import text_t
 
-from celery.tests.case import Case, skip_if_win32
+from celery.tests.case import Case, skip
 
 
-@skip_if_win32()
+@skip.if_win32()
 class test_colored(Case):
 
     def setUp(self):

+ 2 - 2
celery/tests/utils/test_threads.py

@@ -8,7 +8,7 @@ from celery.utils.threads import (
     bgThread,
 )
 
-from celery.tests.case import Case, override_stdouts, patch
+from celery.tests.case import Case, mock, patch
 
 
 class test_bgThread(Case):
@@ -21,7 +21,7 @@ class test_bgThread(Case):
                 raise KeyError()
 
         with patch('os._exit') as _exit:
-            with override_stdouts():
+            with mock.stdouts():
                 _exit.side_effect = ValueError()
                 t = T()
                 with self.assertRaises(ValueError):

+ 1 - 1
celery/tests/utils/test_timer2.py

@@ -52,7 +52,7 @@ class test_Timer(Case):
         t = timer2.Timer(on_tick=on_tick)
         ne = t._next_entry = Mock(name='_next_entry')
         ne.return_value = 3.33
-        self.on_nth_call_do(ne, t._is_shutdown.set, 3)
+        ne.on_nth_call_do(t._is_shutdown.set, 3)
         t.run()
         sleep.assert_called_with(3.33)
         on_tick.assert_has_calls([call(3.33), call(3.33), call(3.33)])

+ 3 - 3
celery/tests/worker/test_autoreload.py

@@ -18,7 +18,7 @@ from celery.worker.autoreload import (
     Autoreloader,
 )
 
-from celery.tests.case import AppCase, Case, Mock, patch, mock_open
+from celery.tests.case import AppCase, Case, Mock, mock, patch
 
 
 class test_WorkerComponent(AppCase):
@@ -52,11 +52,11 @@ class test_WorkerComponent(AppCase):
 class test_file_hash(Case):
 
     def test_hash(self):
-        with mock_open() as a:
+        with mock.open() as a:
             a.write('the quick brown fox\n')
             a.seek(0)
             A = file_hash('foo')
-        with mock_open() as b:
+        with mock.open() as b:
             b.write('the quick brown bar\n')
             b.seek(0)
             B = file_hash('bar')

+ 4 - 2
celery/tests/worker/test_autoscale.py

@@ -6,9 +6,10 @@ from celery.concurrency.base import BasePool
 from celery.five import monotonic
 from celery.worker import state
 from celery.worker import autoscale
-from celery.tests.case import AppCase, Mock, patch, sleepdeprived
 from celery.utils.objects import Bunch
 
+from celery.tests.case import AppCase, Mock, mock, patch
+
 
 class MockPool(BasePool):
     shrink_raises_exception = False
@@ -88,7 +89,7 @@ class test_Autoscaler(AppCase):
         x.stop()
         self.assertFalse(x.joined)
 
-    @sleepdeprived(autoscale)
+    @mock.sleepdeprived(module=autoscale)
     def test_body(self):
         worker = Mock(name='worker')
         x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
@@ -193,6 +194,7 @@ class test_Autoscaler(AppCase):
         _exit.assert_called_with(1)
         self.assertTrue(stderr.write.call_count)
 
+    @mock.sleepdeprived(module=autoscale)
     def test_no_negative_scale(self):
         total_num_processes = []
         worker = Mock(name='worker')

+ 2 - 2
celery/tests/worker/test_components.py

@@ -7,7 +7,7 @@ from __future__ import absolute_import, unicode_literals
 from celery.exceptions import ImproperlyConfigured
 from celery.worker.components import Beat, Hub, Pool, Timer
 
-from celery.tests.case import AppCase, Mock, patch, skip_if_win32
+from celery.tests.case import AppCase, Mock, patch, skip
 
 
 class test_Timer(AppCase):
@@ -60,7 +60,7 @@ class test_Pool(AppCase):
         comp.close(w)
         comp.terminate(w)
 
-    @skip_if_win32()
+    @skip.if_win32()
     def test_create_when_eventloop(self):
         w = Mock()
         w.use_eventloop = w.pool_putlocks = w.pool_cls.uses_semaphore = True

+ 2 - 4
celery/tests/worker/test_consumer.py

@@ -13,9 +13,7 @@ from celery.worker.consumer.heart import Heart
 from celery.worker.consumer.mingle import Mingle
 from celery.worker.consumer.tasks import Tasks
 
-from celery.tests.case import (
-    AppCase, ContextMock, Mock, call, patch, skip_if_python3,
-)
+from celery.tests.case import AppCase, ContextMock, Mock, call, patch, skip
 
 
 class test_Consumer(AppCase):
@@ -45,7 +43,7 @@ class test_Consumer(AppCase):
         c = self.get_consumer()
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
 
-    @skip_if_python3(reason='buffer type not available')
+    @skip.if_python3(reason='buffer type not available')
     def test_dump_body_buffer(self):
         msg = Mock()
         msg.body = 'str'

+ 2 - 2
celery/tests/worker/test_request.py

@@ -48,7 +48,7 @@ from celery.tests.case import (
     TaskMessage,
     task_message_from_sig,
     patch,
-    skip_if_python3,
+    skip,
 )
 
 
@@ -123,7 +123,7 @@ def jail(app, task_id, name, args, kwargs):
     ).retval
 
 
-@skip_if_python3
+@skip.if_python3()
 class test_default_encode(AppCase):
 
     def test_jython(self):

+ 2 - 2
celery/tests/worker/test_worker.py

@@ -30,7 +30,7 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
-from celery.tests.case import AppCase, Mock, TaskMessage, patch, todo
+from celery.tests.case import AppCase, Mock, TaskMessage, patch, skip
 
 
 def MockStep(step=None):
@@ -849,7 +849,7 @@ class test_WorkController(AppCase):
             self.worker._send_worker_shutdown()
             ws.send.assert_called_with(sender=self.worker)
 
-    @todo('unstable test')
+    @skip.todo('unstable test')
     def test_process_shutdown_on_worker_shutdown(self):
         from celery.concurrency.prefork import process_destructor
         from celery.concurrency.asynpool import Worker

+ 1 - 3
requirements/test.txt

@@ -1,3 +1 @@
--r deps/mock.txt
--r deps/nose.txt
-unittest2>=0.5.1
+case

+ 0 - 1
requirements/test3.txt

@@ -1 +0,0 @@
--r deps/nose.txt

+ 1 - 5
setup.py

@@ -171,10 +171,6 @@ install_requires = reqs('default.txt')
 if JYTHON:
     install_requires.extend(reqs('jython.txt'))
 
-# -*- Tests Requires -*-
-
-tests_require = reqs('test3.txt' if PY3 else 'test.txt')
-
 # -*- Long Description -*-
 
 if os.path.exists('README.rst'):
@@ -219,7 +215,7 @@ setup(
     include_package_data=False,
     zip_safe=False,
     install_requires=install_requires,
-    tests_require=tests_require,
+    tests_require=reqs('test.txt'),
     test_suite='nose.collector',
     classifiers=classifiers,
     entry_points=entrypoints,

+ 1 - 5
tox.ini

@@ -4,15 +4,11 @@ envlist = 2.7,pypy,3.4,3.5,pypy3,flake8,flakeplus
 [testenv]
 deps=
     -r{toxinidir}/requirements/default.txt
+    -r{toxinidir}/requirements/test.txt
 
-    2.7,pypy: -r{toxinidir}/requirements/test.txt
     2.7: -r{toxinidir}/requirements/test-ci-default.txt
-
-    3.4,3.5,pypy3: -r{toxinidir}/requirements/test3.txt
     3.4,3.5: -r{toxinidir}/requirements/test-ci-default.txt
-
     pypy,pypy3: -r{toxinidir}/requirements/test-ci-base.txt
-    pypy3: -r{toxinidir}/requirements/test-pypy3.txt
 
 sitepackages = False
 recreate = False