Browse Source

[tests] Cleanup SkipTest stuff

Ask Solem 9 years ago
parent
commit
1a572fb55c
40 changed files with 1549 additions and 1540 deletions
  1. 799 0
      celery/tests/_case.py
  2. 2 2
      celery/tests/app/test_app.py
  3. 2 6
      celery/tests/app/test_beat.py
  4. 2 2
      celery/tests/app/test_loaders.py
  5. 6 6
      celery/tests/app/test_log.py
  6. 3 6
      celery/tests/app/test_schedules.py
  7. 4 3
      celery/tests/backends/test_base.py
  8. 2 2
      celery/tests/backends/test_cache.py
  9. 2 52
      celery/tests/backends/test_couchbase.py
  10. 2 3
      celery/tests/backends/test_couchdb.py
  11. 6 13
      celery/tests/backends/test_database.py
  12. 2 8
      celery/tests/backends/test_elasticsearch.py
  13. 2 4
      celery/tests/backends/test_filesystem.py
  14. 7 13
      celery/tests/backends/test_mongodb.py
  15. 6 8
      celery/tests/backends/test_redis.py
  16. 3 4
      celery/tests/backends/test_riak.py
  17. 2 1
      celery/tests/bin/test_amqp.py
  18. 3 2
      celery/tests/bin/test_celery.py
  19. 2 1
      celery/tests/bin/test_celeryevdump.py
  20. 2 6
      celery/tests/bin/test_events.py
  21. 3 3
      celery/tests/bin/test_multi.py
  22. 35 61
      celery/tests/bin/test_worker.py
  23. 33 682
      celery/tests/case.py
  24. 1 2
      celery/tests/concurrency/test_eventlet.py
  25. 1 1
      celery/tests/concurrency/test_gevent.py
  26. 2 5
      celery/tests/concurrency/test_pool.py
  27. 7 16
      celery/tests/concurrency/test_prefork.py
  28. 4 3
      celery/tests/contrib/test_rdb.py
  29. 2 6
      celery/tests/events/test_cursesmon.py
  30. 2 2
      celery/tests/events/test_state.py
  31. 3 7
      celery/tests/security/case.py
  32. 2 2
      celery/tests/security/test_certificate.py
  33. 3 11
      celery/tests/utils/test_datastructures.py
  34. 571 561
      celery/tests/utils/test_platforms.py
  35. 3 9
      celery/tests/utils/test_sysinfo.py
  36. 2 4
      celery/tests/utils/test_term.py
  37. 2 4
      celery/tests/worker/test_components.py
  38. 5 6
      celery/tests/worker/test_consumer.py
  39. 7 11
      celery/tests/worker/test_request.py
  40. 2 2
      celery/tests/worker/test_worker.py

+ 799 - 0
celery/tests/_case.py

@@ -0,0 +1,799 @@
+from __future__ import absolute_import, unicode_literals
+
+import importlib
+import inspect
+import io
+import logging
+import os
+import platform
+import re
+import sys
+import time
+import types
+import warnings
+
+from contextlib import contextmanager
+from functools import partial, wraps
+from six import (
+    iteritems as items,
+    itervalues as values,
+    string_types,
+    reraise,
+)
+from six.moves import builtins
+
+from nose import SkipTest
+
+try:
+    import unittest  # noqa
+    unittest.skip
+    from unittest.util import safe_repr, unorderable_list_difference
+except AttributeError:
+    import unittest2 as unittest  # noqa
+    from unittest2.util import safe_repr, unorderable_list_difference  # noqa
+
+try:
+    from unittest import mock
+except ImportError:
+    import mock  # noqa
+
+__all__ = [
+    'ANY', 'Case', 'ContextMock', 'MagicMock', 'Mock', 'MockCallbacks',
+    'call', 'patch', 'sentinel',
+
+    'mock_open', 'mock_context', 'mock_module',
+    'patch_modules', 'reset_modules', 'sys_platform', 'pypy_version',
+    'platform_pyimp', 'replace_module_value', 'override_stdouts',
+    'mask_modules', 'sleepdeprived', 'mock_environ', 'wrap_logger',
+    'restore_logging',
+
+    'todo', 'skip', 'skip_if_darwin', 'skip_if_environ',
+    'skip_if_jython', 'skip_if_platform', 'skip_if_pypy', 'skip_if_python3',
+    'skip_if_win32', 'skip_unless_module', 'skip_unless_symbol',
+]
+
+patch = mock.patch
+call = mock.call
+sentinel = mock.sentinel
+MagicMock = mock.MagicMock
+ANY = mock.ANY
+
+PY3 = sys.version_info[0] == 3
+if PY3:
+    open_fqdn = 'builtins.open'
+    module_name_t = str
+else:
+    open_fqdn = '__builtin__.open'  # noqa
+    module_name_t = bytes  # noqa
+
+StringIO = io.StringIO
+_SIO_write = StringIO.write
+_SIO_init = StringIO.__init__
+
+
+def symbol_by_name(name, aliases={}, imp=None, package=None,
+                   sep='.', default=None, **kwargs):
+    """Get symbol by qualified name.
+
+    The name should be the full dot-separated path to the class::
+
+        modulename.ClassName
+
+    Example::
+
+        celery.concurrency.processes.TaskPool
+                                    ^- class name
+
+    or using ':' to separate module and symbol::
+
+        celery.concurrency.processes:TaskPool
+
+    If `aliases` is provided, a dict containing short name/long name
+    mappings, the name is looked up in the aliases first.
+
+    Examples:
+
+        >>> symbol_by_name('celery.concurrency.processes.TaskPool')
+        <class 'celery.concurrency.processes.TaskPool'>
+
+        >>> symbol_by_name('default', {
+        ...     'default': 'celery.concurrency.processes.TaskPool'})
+        <class 'celery.concurrency.processes.TaskPool'>
+
+        # Does not try to look up non-string names.
+        >>> from celery.concurrency.processes import TaskPool
+        >>> symbol_by_name(TaskPool) is TaskPool
+        True
+
+    """
+    if imp is None:
+        imp = importlib.import_module
+
+    if not isinstance(name, string_types):
+        return name                                 # already a class
+
+    name = aliases.get(name) or name
+    sep = ':' if ':' in name else sep
+    module_name, _, cls_name = name.rpartition(sep)
+    if not module_name:
+        cls_name, module_name = None, package if package else cls_name
+    try:
+        try:
+            module = imp(module_name, package=package, **kwargs)
+        except ValueError as exc:
+            reraise(ValueError,
+                    ValueError("Couldn't import {0!r}: {1}".format(name, exc)),
+                    sys.exc_info()[2])
+        return getattr(module, cls_name) if cls_name else module
+    except (ImportError, AttributeError):
+        if default is None:
+            raise
+    return default
+
+
+class WhateverIO(StringIO):
+
+    def __init__(self, v=None, *a, **kw):
+        _SIO_init(self, v.decode() if isinstance(v, bytes) else v, *a, **kw)
+
+    def write(self, data):
+        _SIO_write(self, data.decode() if isinstance(data, bytes) else data)
+
+
+def noop(*args, **kwargs):
+    pass
+
+
+class Mock(mock.Mock):
+
+    def __init__(self, *args, **kwargs):
+        attrs = kwargs.pop('attrs', None) or {}
+        super(Mock, self).__init__(*args, **kwargs)
+        for attr_name, attr_value in items(attrs):
+            setattr(self, attr_name, attr_value)
+
+
+class _ContextMock(Mock):
+    """Dummy class implementing __enter__ and __exit__
+    as the :keyword:`with` statement requires these to be implemented
+    in the class, not just the instance."""
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        pass
+
+
+def ContextMock(*args, **kwargs):
+    obj = _ContextMock(*args, **kwargs)
+    obj.attach_mock(_ContextMock(), '__enter__')
+    obj.attach_mock(_ContextMock(), '__exit__')
+    obj.__enter__.return_value = obj
+    # if __exit__ return a value the exception is ignored,
+    # so it must return None here.
+    obj.__exit__.return_value = None
+    return obj
+
+
+def _bind(f, o):
+    @wraps(f)
+    def bound_meth(*fargs, **fkwargs):
+        return f(o, *fargs, **fkwargs)
+    return bound_meth
+
+
+if PY3:  # pragma: no cover
+    def _get_class_fun(meth):
+        return meth
+else:
+    def _get_class_fun(meth):
+        return meth.__func__
+
+
+class MockCallbacks(object):
+
+    def __new__(cls, *args, **kwargs):
+        r = Mock(name=cls.__name__)
+        _get_class_fun(cls.__init__)(r, *args, **kwargs)
+        for key, value in items(vars(cls)):
+            if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
+                if inspect.ismethod(value) or inspect.isfunction(value):
+                    r.__getattr__(key).side_effect = _bind(value, r)
+                else:
+                    r.__setattr__(key, value)
+        return r
+
+
+# -- adds assertWarns from recent unittest2, not in Python 2.7.
+
+class _AssertRaisesBaseContext(object):
+
+    def __init__(self, expected, test_case, callable_obj=None,
+                 expected_regex=None):
+        self.expected = expected
+        self.failureException = test_case.failureException
+        self.obj_name = None
+        if isinstance(expected_regex, string_types):
+            expected_regex = re.compile(expected_regex)
+        self.expected_regex = expected_regex
+
+
+def _is_magic_module(m):
+    # some libraries create custom module types that are lazily
+    # lodaded, e.g. Django installs some modules in sys.modules that
+    # will load _tkinter and other shit when touched.
+
+    # pyflakes refuses to accept 'noqa' for this isinstance.
+    cls, modtype = type(m), types.ModuleType
+    try:
+        variables = vars(cls)
+    except TypeError:
+        return True
+    else:
+        return (cls is not modtype and (
+            '__getattr__' in variables or
+            '__getattribute__' in variables))
+
+
+class _AssertWarnsContext(_AssertRaisesBaseContext):
+    """A context manager used to implement TestCase.assertWarns* methods."""
+
+    def __enter__(self):
+        # The __warningregistry__'s need to be in a pristine state for tests
+        # to work properly.
+        warnings.resetwarnings()
+        for v in list(values(sys.modules)):
+            # do not evaluate Django moved modules and other lazily
+            # initialized modules.
+            if v and not _is_magic_module(v):
+                # use raw __getattribute__ to protect even better from
+                # lazily loaded modules
+                try:
+                    object.__getattribute__(v, '__warningregistry__')
+                except AttributeError:
+                    pass
+                else:
+                    object.__setattr__(v, '__warningregistry__', {})
+        self.warnings_manager = warnings.catch_warnings(record=True)
+        self.warnings = self.warnings_manager.__enter__()
+        warnings.simplefilter('always', self.expected)
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        self.warnings_manager.__exit__(exc_type, exc_value, tb)
+        if exc_type is not None:
+            # let unexpected exceptions pass through
+            return
+        try:
+            exc_name = self.expected.__name__
+        except AttributeError:
+            exc_name = str(self.expected)
+        first_matching = None
+        for m in self.warnings:
+            w = m.message
+            if not isinstance(w, self.expected):
+                continue
+            if first_matching is None:
+                first_matching = w
+            if (self.expected_regex is not None and
+                    not self.expected_regex.search(str(w))):
+                continue
+            # store warning for later retrieval
+            self.warning = w
+            self.filename = m.filename
+            self.lineno = m.lineno
+            return
+        # Now we simply try to choose a helpful failure message
+        if first_matching is not None:
+            raise self.failureException(
+                '%r does not match %r' % (
+                    self.expected_regex.pattern, str(first_matching)))
+        if self.obj_name:
+            raise self.failureException(
+                '%s not triggered by %s' % (exc_name, self.obj_name))
+        else:
+            raise self.failureException('%s not triggered' % exc_name)
+
+
+class Case(unittest.TestCase):
+    DeprecationWarning = DeprecationWarning
+    PendingDeprecationWarning = PendingDeprecationWarning
+
+    def patch(self, *path, **options):
+        manager = patch('.'.join(path), **options)
+        patched = manager.start()
+        self.addCleanup(manager.stop)
+        return patched
+
+    def mock_modules(self, *mods):
+        modules = []
+        for mod in mods:
+            mod = mod.split('.')
+            modules.extend(reversed([
+                '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
+            ]))
+        modules = sorted(set(modules))
+        return self.wrap_context(mock_module(*modules))
+
+    def on_nth_call_do(self, mock, side_effect, n=1):
+
+        def on_call(*args, **kwargs):
+            if mock.call_count >= n:
+                mock.side_effect = side_effect
+            return mock.return_value
+        mock.side_effect = on_call
+        return mock
+
+    def on_nth_call_return(self, mock, retval, n=1):
+
+        def on_call(*args, **kwargs):
+            if mock.call_count >= n:
+                mock.return_value = retval
+            return mock.return_value
+        mock.side_effect = on_call
+        return mock
+
+    def mask_modules(self, *modules):
+        self.wrap_context(mask_modules(*modules))
+
+    def wrap_context(self, context):
+        ret = context.__enter__()
+        self.addCleanup(partial(context.__exit__, None, None, None))
+        return ret
+
+    def mock_environ(self, env_name, env_value):
+        return self.wrap_context(mock_environ(env_name, env_value))
+
+    def assertWarns(self, expected_warning):
+        return _AssertWarnsContext(expected_warning, self, None)
+
+    def assertWarnsRegex(self, expected_warning, expected_regex):
+        return _AssertWarnsContext(expected_warning, self,
+                                   None, expected_regex)
+
+    @contextmanager
+    def assertDeprecated(self):
+        with self.assertWarnsRegex(self.DeprecationWarning,
+                                   r'scheduled for removal'):
+            yield
+
+    @contextmanager
+    def assertPendingDeprecation(self):
+        with self.assertWarnsRegex(self.PendingDeprecationWarning,
+                                   r'scheduled for deprecation'):
+            yield
+
+    def assertDictContainsSubset(self, expected, actual, msg=None):
+        missing, mismatched = [], []
+
+        for key, value in items(expected):
+            if key not in actual:
+                missing.append(key)
+            elif value != actual[key]:
+                mismatched.append('%s, expected: %s, actual: %s' % (
+                    safe_repr(key), safe_repr(value),
+                    safe_repr(actual[key])))
+
+        if not (missing or mismatched):
+            return
+
+        standard_msg = ''
+        if missing:
+            standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
+
+        if mismatched:
+            if standard_msg:
+                standard_msg += '; '
+            standard_msg += 'Mismatched values: %s' % (
+                ','.join(mismatched))
+
+        self.fail(self._formatMessage(msg, standard_msg))
+
+    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
+        missing = unexpected = None
+        try:
+            expected = sorted(expected_seq)
+            actual = sorted(actual_seq)
+        except TypeError:
+            # Unsortable items (example: set(), complex(), ...)
+            expected = list(expected_seq)
+            actual = list(actual_seq)
+            missing, unexpected = unorderable_list_difference(
+                expected, actual)
+        else:
+            return self.assertSequenceEqual(expected, actual, msg=msg)
+
+        errors = []
+        if missing:
+            errors.append(
+                'Expected, but missing:\n    %s' % (safe_repr(missing),)
+            )
+        if unexpected:
+            errors.append(
+                'Unexpected, but present:\n    %s' % (safe_repr(unexpected),)
+            )
+        if errors:
+            standardMsg = '\n'.join(errors)
+            self.fail(self._formatMessage(msg, standardMsg))
+
+
+class _CallableContext(object):
+
+    def __init__(self, context, cargs, ckwargs, fun):
+        self.context = context
+        self.cargs = cargs
+        self.ckwargs = ckwargs
+        self.fun = fun
+
+    def __call__(self, *args, **kwargs):
+        return self.fun(*args, **kwargs)
+
+    def __enter__(self):
+        self.ctx = self.context(*self.cargs, **self.ckwargs)
+        return self.ctx.__enter__()
+
+    def __exit__(self, *einfo):
+        if self.ctx:
+            return self.ctx.__exit__(*einfo)
+
+
+def decorator(predicate):
+
+    @wraps(predicate)
+    def take_arguments(*pargs, **pkwargs):
+
+        @wraps(predicate)
+        def decorator(cls):
+            if inspect.isclass(cls):
+                orig_setup = cls.setUp
+                orig_teardown = cls.tearDown
+
+                @wraps(cls.setUp)
+                def around_setup(*args, **kwargs):
+                    try:
+                        contexts = args[0].__rb3dc_contexts__
+                    except AttributeError:
+                        contexts = args[0].__rb3dc_contexts__ = []
+                    p = predicate(*pargs, **pkwargs)
+                    p.__enter__()
+                    contexts.append(p)
+                    return orig_setup(*args, **kwargs)
+                around_setup.__wrapped__ = cls.setUp
+                cls.setUp = around_setup
+
+                @wraps(cls.tearDown)
+                def around_teardown(*args, **kwargs):
+                    try:
+                        contexts = args[0].__rb3dc_contexts__
+                    except AttributeError:
+                        pass
+                    else:
+                        for context in contexts:
+                            context.__exit__(*sys.exc_info())
+                    orig_teardown(*args, **kwargs)
+                around_teardown.__wrapped__ = cls.tearDown
+                cls.tearDown = around_teardown
+
+                return cls
+            else:
+                @wraps(cls)
+                def around_case(*args, **kwargs):
+                    with predicate(*pargs, **pkwargs):
+                        return cls(*args, **kwargs)
+                return around_case
+
+        if len(pargs) == 1 and callable(pargs[0]):
+            fun, pargs = pargs[0], ()
+            return decorator(fun)
+        return _CallableContext(predicate, pargs, pkwargs, decorator)
+    return take_arguments
+
+
+@decorator
+@contextmanager
+def skip_unless_module(module, name=None):
+    try:
+        importlib.import_module(module)
+    except (ImportError, OSError):
+        raise SkipTest('module not installed: {0}'.format(name or module))
+    yield
+
+
+@decorator
+@contextmanager
+def skip_unless_symbol(symbol, name=None):
+    try:
+        symbol_by_name(symbol)
+    except (AttributeError, ImportError):
+        raise SkipTest('missing symbol {0}'.format(name or symbol))
+    yield
+
+
+def get_logger_handlers(logger):
+    return [
+        h for h in logger.handlers
+        if not isinstance(h, logging.NullHandler)
+    ]
+
+
+@decorator
+@contextmanager
+def wrap_logger(logger, loglevel=logging.ERROR):
+    old_handlers = get_logger_handlers(logger)
+    sio = WhateverIO()
+    siohandler = logging.StreamHandler(sio)
+    logger.handlers = [siohandler]
+
+    try:
+        yield sio
+    finally:
+        logger.handlers = old_handlers
+
+
+@decorator
+@contextmanager
+def mock_environ(env_name, env_value):
+    sentinel = object()
+    prev_val = os.environ.get(env_name, sentinel)
+    os.environ[env_name] = env_value
+    try:
+        yield env_value
+    finally:
+        if prev_val is sentinel:
+            os.environ.pop(env_name, None)
+        else:
+            os.environ[env_name] = prev_val
+
+
+@decorator
+@contextmanager
+def sleepdeprived(module=time):
+    old_sleep, module.sleep = module.sleep, noop
+    try:
+        yield
+    finally:
+        module.sleep = old_sleep
+
+
+@decorator
+@contextmanager
+def skip_if_python3(reason='incompatible'):
+    if PY3:
+        raise SkipTest('Python3: {0}'.format(reason))
+    yield
+
+
+@decorator
+@contextmanager
+def skip_if_environ(env_var_name):
+    if os.environ.get(env_var_name):
+        raise SkipTest('envvar {0} set'.format(env_var_name))
+    yield
+
+
+@decorator
+@contextmanager
+def _skip_test(reason, sign):
+    raise SkipTest('{0}: {1}'.format(sign, reason))
+    yield
+todo = partial(_skip_test, sign='TODO')
+skip = partial(_skip_test, sign='SKIP')
+
+
+# Taken from
+# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
+@decorator
+@contextmanager
+def mask_modules(*modnames):
+    """Ban some modules from being importable inside the context
+
+    For example:
+
+        >>> with mask_modules('sys'):
+        ...     try:
+        ...         import sys
+        ...     except ImportError:
+        ...         print('sys not found')
+        sys not found
+
+        >>> import sys  # noqa
+        >>> sys.version
+        (2, 5, 2, 'final', 0)
+
+    """
+    realimport = builtins.__import__
+
+    def myimp(name, *args, **kwargs):
+        if name in modnames:
+            raise ImportError('No module named %s' % name)
+        else:
+            return realimport(name, *args, **kwargs)
+
+    builtins.__import__ = myimp
+    try:
+        yield True
+    finally:
+        builtins.__import__ = realimport
+
+
+@decorator
+@contextmanager
+def override_stdouts():
+    """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
+    prev_out, prev_err = sys.stdout, sys.stderr
+    prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
+    mystdout, mystderr = WhateverIO(), WhateverIO()
+    sys.stdout = sys.__stdout__ = mystdout
+    sys.stderr = sys.__stderr__ = mystderr
+
+    try:
+        yield mystdout, mystderr
+    finally:
+        sys.stdout = prev_out
+        sys.stderr = prev_err
+        sys.__stdout__ = prev_rout
+        sys.__stderr__ = prev_rerr
+
+
+@decorator
+@contextmanager
+def replace_module_value(module, name, value=None):
+    has_prev = hasattr(module, name)
+    prev = getattr(module, name, None)
+    if value:
+        setattr(module, name, value)
+    else:
+        try:
+            delattr(module, name)
+        except AttributeError:
+            pass
+    try:
+        yield
+    finally:
+        if prev is not None:
+            setattr(module, name, prev)
+        if not has_prev:
+            try:
+                delattr(module, name)
+            except AttributeError:
+                pass
+pypy_version = partial(
+    replace_module_value, sys, 'pypy_version_info',
+)
+platform_pyimp = partial(
+    replace_module_value, platform, 'python_implementation',
+)
+
+
+@decorator
+@contextmanager
+def sys_platform(value):
+    prev, sys.platform = sys.platform, value
+    try:
+        yield
+    finally:
+        sys.platform = prev
+
+
+@decorator
+@contextmanager
+def reset_modules(*modules):
+    prev = {k: sys.modules.pop(k) for k in modules if k in sys.modules}
+    try:
+        yield
+    finally:
+        sys.modules.update(prev)
+
+
+@decorator
+@contextmanager
+def patch_modules(*modules):
+    prev = {}
+    for mod in modules:
+        prev[mod] = sys.modules.get(mod)
+        sys.modules[mod] = types.ModuleType(module_name_t(mod))
+    try:
+        yield
+    finally:
+        for name, mod in items(prev):
+            if mod is None:
+                sys.modules.pop(name, None)
+            else:
+                sys.modules[name] = mod
+
+
+@decorator
+@contextmanager
+def mock_module(*names):
+    prev = {}
+
+    class MockModule(types.ModuleType):
+
+        def __getattr__(self, attr):
+            setattr(self, attr, Mock())
+            return types.ModuleType.__getattribute__(self, attr)
+
+    mods = []
+    for name in names:
+        try:
+            prev[name] = sys.modules[name]
+        except KeyError:
+            pass
+        mod = sys.modules[name] = MockModule(module_name_t(name))
+        mods.append(mod)
+    try:
+        yield mods
+    finally:
+        for name in names:
+            try:
+                sys.modules[name] = prev[name]
+            except KeyError:
+                try:
+                    del(sys.modules[name])
+                except KeyError:
+                    pass
+
+
+@contextmanager
+def mock_context(mock, typ=Mock):
+    context = mock.return_value = Mock()
+    context.__enter__ = typ()
+    context.__exit__ = typ()
+
+    def on_exit(*x):
+        if x[0]:
+            reraise(x[0], x[1], x[2])
+    context.__exit__.side_effect = on_exit
+    context.__enter__.return_value = context
+    try:
+        yield context
+    finally:
+        context.reset()
+
+
+@decorator
+@contextmanager
+def mock_open(typ=WhateverIO, side_effect=None):
+    with patch(open_fqdn) as open_:
+        with mock_context(open_) as context:
+            if side_effect is not None:
+                context.__enter__.side_effect = side_effect
+            val = context.__enter__.return_value = typ()
+            val.__exit__ = Mock()
+            yield val
+
+
+@decorator
+@contextmanager
+def skip_if_platform(platform_name, name=None):
+    if sys.platform.startswith(platform_name):
+        raise SkipTest('does not work on {0}'.format(platform_name or name))
+    yield
+skip_if_jython = partial(skip_if_platform, 'java', name='Jython')
+skip_if_win32 = partial(skip_if_platform, 'win32', name='Windows')
+skip_if_darwin = partial(skip_if_platform, 'darwin', name='OS X')
+
+
+@decorator
+@contextmanager
+def skip_if_pypy():
+    if getattr(sys, 'pypy_version_info', None):
+        raise SkipTest('does not work on PyPy')
+    yield
+
+
+@decorator
+@contextmanager
+def restore_logging():
+    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
+    root = logging.getLogger()
+    level = root.level
+    handlers = root.handlers
+
+    try:
+        yield
+    finally:
+        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
+        root.level = level
+        root.handlers[:] = handlers

+ 2 - 2
celery/tests/app/test_app.py

@@ -34,7 +34,7 @@ from celery.tests.case import (
     platform_pyimp,
     platform_pyimp,
     sys_platform,
     sys_platform,
     pypy_version,
     pypy_version,
-    with_environ,
+    mock_environ,
 )
 )
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.mail import ErrorMail
 from celery.utils.mail import ErrorMail
@@ -236,7 +236,7 @@ class test_App(AppCase):
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
         )
         )
 
 
-    @with_environ('CELERY_BROKER_URL', '')
+    @mock_environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
     def test_with_broker(self):
         with self.Celery(broker='foo://baribaz') as app:
         with self.Celery(broker='foo://baribaz') as app:
             self.assertEqual(app.conf.broker_url, 'foo://baribaz')
             self.assertEqual(app.conf.broker_url, 'foo://baribaz')

+ 2 - 6
celery/tests/app/test_beat.py

@@ -11,7 +11,7 @@ from celery.schedules import schedule
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import AppCase, Mock, SkipTest, call, patch
+from celery.tests.case import AppCase, Mock, call, patch, skip_unless_module
 
 
 
 
 class MockShelve(dict):
 class MockShelve(dict):
@@ -485,12 +485,8 @@ class test_Service(AppCase):
 
 
 class test_EmbeddedService(AppCase):
 class test_EmbeddedService(AppCase):
 
 
+    @skip_unless_module('_multiprocessing', name='multiprocessing')
     def test_start_stop_process(self):
     def test_start_stop_process(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('multiprocessing not available')
-
         from billiard.process import Process
         from billiard.process import Process
 
 
         s = beat.EmbeddedService(self.app)
         s = beat.EmbeddedService(self.app)

+ 2 - 2
celery/tests/app/test_loaders.py

@@ -13,7 +13,7 @@ from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 from celery.utils.imports import NotAPackage
 from celery.utils.mail import SendmailWarning
 from celery.utils.mail import SendmailWarning
 
 
-from celery.tests.case import AppCase, Case, Mock, patch, with_environ
+from celery.tests.case import AppCase, Case, Mock, mock_environ, patch
 
 
 
 
 class DummyLoader(base.BaseLoader):
 class DummyLoader(base.BaseLoader):
@@ -144,7 +144,7 @@ class test_DefaultLoader(AppCase):
             l.read_configuration(fail_silently=False)
             l.read_configuration(fail_silently=False)
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
-    @with_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
+    @mock_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
     def test_read_configuration_py_in_name(self, find_module):
     def test_read_configuration_py_in_name(self, find_module):
         find_module.side_effect = NotAPackage()
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)

+ 6 - 6
celery/tests/app/test_log.py

@@ -21,9 +21,10 @@ from celery.utils.log import (
     logger_isa,
     logger_isa,
 )
 )
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Mock, SkipTest, mask_modules,
-    get_handlers, override_stdouts, patch, wrap_logger, restore_logging,
+    AppCase, Mock, mask_modules, skip_if_python3,
+    override_stdouts, patch, wrap_logger, restore_logging,
 )
 )
+from celery.tests._case import get_logger_handlers
 
 
 
 
 class test_TaskFormatter(AppCase):
 class test_TaskFormatter(AppCase):
@@ -155,10 +156,9 @@ class test_ColorFormatter(AppCase):
         self.assertIn('<Unrepresentable', msg)
         self.assertIn('<Unrepresentable', msg)
         self.assertEqual(safe_str.call_count, 1)
         self.assertEqual(safe_str.call_count, 1)
 
 
+    @skip_if_python3()
     @patch('celery.utils.log.safe_str')
     @patch('celery.utils.log.safe_str')
     def test_format_raises_no_color(self, safe_str):
     def test_format_raises_no_color(self, safe_str):
-        if sys.version_info[0] == 3:
-            raise SkipTest('py3k')
         x = ColorFormatter(use_color=False)
         x = ColorFormatter(use_color=False)
         record = Mock()
         record = Mock()
         record.levelname = 'ERROR'
         record.levelname = 'ERROR'
@@ -235,7 +235,7 @@ class test_default_logger(AppCase):
             logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
             logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                        root=False, colorize=None)
                                        root=False, colorize=None)
             self.assertIs(
             self.assertIs(
-                get_handlers(logger)[0].stream, sys.__stderr__,
+                get_logger_handlers(logger)[0].stream, sys.__stderr__,
                 'setup_logger logs to stderr without logfile argument.',
                 'setup_logger logs to stderr without logfile argument.',
             )
             )
 
 
@@ -273,7 +273,7 @@ class test_default_logger(AppCase):
                     logfile=tempfile, loglevel=logging.INFO, root=False,
                     logfile=tempfile, loglevel=logging.INFO, root=False,
                 )
                 )
                 self.assertIsInstance(
                 self.assertIsInstance(
-                    get_handlers(l)[0], logging.FileHandler,
+                    get_logger_handlers(l)[0], logging.FileHandler,
                 )
                 )
                 self.assertIn(tempfile, files)
                 self.assertIn(tempfile, files)
 
 

+ 3 - 6
celery/tests/app/test_schedules.py

@@ -10,7 +10,7 @@ from celery.five import items
 from celery.schedules import (
 from celery.schedules import (
     ParseException, crontab, crontab_parser, schedule, solar,
     ParseException, crontab, crontab_parser, schedule, solar,
 )
 )
-from celery.tests.case import AppCase, Mock, SkipTest
+from celery.tests.case import AppCase, Mock, skip_unless_module, todo
 
 
 
 
 @contextmanager
 @contextmanager
@@ -23,13 +23,10 @@ def patch_crontab_nowfun(cls, retval):
         cls.nowfun = prev_nowfun
         cls.nowfun = prev_nowfun
 
 
 
 
+@skip_unless_module('ephem')
 class test_solar(AppCase):
 class test_solar(AppCase):
 
 
     def setup(self):
     def setup(self):
-        try:
-            import ephem  # noqa
-        except ImportError:
-            raise SkipTest('ephem module not installed')
         self.s = solar('sunrise', 60, 30, app=self.app)
         self.s = solar('sunrise', 60, 30, app=self.app)
 
 
     def test_reduce(self):
     def test_reduce(self):
@@ -738,8 +735,8 @@ class test_crontab_is_due(AppCase):
             self.assertTrue(due)
             self.assertTrue(due)
             self.assertEqual(remaining, 60.)
             self.assertEqual(remaining, 60.)
 
 
+    @todo('unstable test')
     def test_monthly_moy_execution_is_not_due(self):
     def test_monthly_moy_execution_is_not_due(self):
-        raise SkipTest('unstable test')
         with patch_crontab_nowfun(
         with patch_crontab_nowfun(
                 self.monthly_moy, datetime(2013, 6, 28, 14, 30)):
                 self.monthly_moy, datetime(2013, 6, 28, 14, 30)):
             due, remaining = self.monthly_moy.is_due(
             due, remaining = self.monthly_moy.is_due(

+ 4 - 3
celery/tests/backends/test_base.py

@@ -25,7 +25,9 @@ from celery.result import result_from_tuple
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.functional import pass1
 from celery.utils.functional import pass1
 
 
-from celery.tests.case import ANY, AppCase, Case, Mock, SkipTest, call, patch
+from celery.tests.case import (
+    ANY, AppCase, Case, Mock, call, patch, skip_if_python3,
+)
 
 
 
 
 class wrapobject(object):
 class wrapobject(object):
@@ -92,9 +94,8 @@ class test_BaseBackend_interface(AppCase):
 
 
 class test_exception_pickle(AppCase):
 class test_exception_pickle(AppCase):
 
 
+    @skip_if_python3('does not support old style classes')
     def test_oldstyle(self):
     def test_oldstyle(self):
-        if Oldstyle is None:
-            raise SkipTest('py3k does not support old style classes')
         self.assertTrue(fnpe(Oldstyle()))
         self.assertTrue(fnpe(Oldstyle()))
 
 
     def test_BaseException(self):
     def test_BaseException(self):

+ 2 - 2
celery/tests/backends/test_cache.py

@@ -16,7 +16,7 @@ from celery.five import items, module_name_t, string, text_t
 from celery.utils import uuid
 from celery.utils import uuid
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Mock, disable_stdouts, mask_modules, patch, reset_modules,
+    AppCase, Mock, override_stdouts, mask_modules, patch, reset_modules,
 )
 )
 
 
 PY3 = sys.version_info[0] == 3
 PY3 = sys.version_info[0] == 3
@@ -136,7 +136,7 @@ class test_CacheBackend(AppCase):
         b = CacheBackend(backend=backend, app=self.app)
         b = CacheBackend(backend=backend, app=self.app)
         self.assertEqual(b.as_uri(), backend)
         self.assertEqual(b.as_uri(), backend)
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_regression_worker_startup_info(self):
     def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
         self.app.conf.result_backend = (
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'

+ 2 - 52
celery/tests/backends/test_couchbase.py

@@ -9,7 +9,7 @@ from celery.backends.couchbase import CouchBaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery import backends
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, SkipTest, patch, sentinel,
+    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
 )
 )
 
 
 try:
 try:
@@ -20,24 +20,13 @@ except ImportError:
 COUCHBASE_BUCKET = 'celery_bucket'
 COUCHBASE_BUCKET = 'celery_bucket'
 
 
 
 
+@skip_unless_module('couchbase')
 class test_CouchBaseBackend(AppCase):
 class test_CouchBaseBackend(AppCase):
 
 
-    """CouchBaseBackend TestCase."""
-
     def setup(self):
     def setup(self):
-        """Skip the test if couchbase cannot be imported."""
-        if couchbase is None:
-            raise SkipTest('couchbase is not installed.')
         self.backend = CouchBaseBackend(app=self.app)
         self.backend = CouchBaseBackend(app=self.app)
 
 
     def test_init_no_couchbase(self):
     def test_init_no_couchbase(self):
-        """
-        Test init no couchbase raises.
-
-        If celery.backends.couchbase cannot import the couchbase client, it
-        sets the couchbase.Couchbase to None and then handles this in the
-        CouchBaseBackend __init__ method.
-        """
         prev, module.Couchbase = module.Couchbase, None
         prev, module.Couchbase = module.Couchbase, None
         try:
         try:
             with self.assertRaises(ImproperlyConfigured):
             with self.assertRaises(ImproperlyConfigured):
@@ -46,18 +35,15 @@ class test_CouchBaseBackend(AppCase):
             module.Couchbase = prev
             module.Couchbase = prev
 
 
     def test_init_no_settings(self):
     def test_init_no_settings(self):
-        """Test init no settings."""
         self.app.conf.couchbase_backend_settings = []
         self.app.conf.couchbase_backend_settings = []
         with self.assertRaises(ImproperlyConfigured):
         with self.assertRaises(ImproperlyConfigured):
             CouchBaseBackend(app=self.app)
             CouchBaseBackend(app=self.app)
 
 
     def test_init_settings_is_None(self):
     def test_init_settings_is_None(self):
-        """Test init settings is None."""
         self.app.conf.couchbase_backend_settings = None
         self.app.conf.couchbase_backend_settings = None
         CouchBaseBackend(app=self.app)
         CouchBaseBackend(app=self.app)
 
 
     def test_get_connection_connection_exists(self):
     def test_get_connection_connection_exists(self):
-        """Test _get_connection works."""
         with patch('couchbase.connection.Connection') as mock_Connection:
         with patch('couchbase.connection.Connection') as mock_Connection:
             self.backend._connection = sentinel._connection
             self.backend._connection = sentinel._connection
 
 
@@ -67,14 +53,6 @@ class test_CouchBaseBackend(AppCase):
             self.assertFalse(mock_Connection.called)
             self.assertFalse(mock_Connection.called)
 
 
     def test_get(self):
     def test_get(self):
-        """
-        Test get method.
-
-        CouchBaseBackend.get should return  and take two params
-        db conn to couchbase is mocked.
-
-        TODO Should test on key not exists
-        """
         self.app.conf.couchbase_backend_settings = {}
         self.app.conf.couchbase_backend_settings = {}
         x = CouchBaseBackend(app=self.app)
         x = CouchBaseBackend(app=self.app)
         x._connection = Mock()
         x._connection = Mock()
@@ -85,12 +63,6 @@ class test_CouchBaseBackend(AppCase):
         x._connection.get.assert_called_once_with('1f3fab')
         x._connection.get.assert_called_once_with('1f3fab')
 
 
     def test_set(self):
     def test_set(self):
-        """
-        Test set method.
-
-        CouchBaseBackend.set should return None and take two params
-        db conn to couchbase is mocked.
-        """
         self.app.conf.couchbase_backend_settings = None
         self.app.conf.couchbase_backend_settings = None
         x = CouchBaseBackend(app=self.app)
         x = CouchBaseBackend(app=self.app)
         x._connection = MagicMock()
         x._connection = MagicMock()
@@ -99,14 +71,6 @@ class test_CouchBaseBackend(AppCase):
         self.assertIsNone(x.set(sentinel.key, sentinel.value))
         self.assertIsNone(x.set(sentinel.key, sentinel.value))
 
 
     def test_delete(self):
     def test_delete(self):
-        """
-        Test delete method.
-
-        CouchBaseBackend.delete should return and take two params
-        db conn to couchbase is mocked.
-
-        TODO Should test on key not exists.
-        """
         self.app.conf.couchbase_backend_settings = {}
         self.app.conf.couchbase_backend_settings = {}
         x = CouchBaseBackend(app=self.app)
         x = CouchBaseBackend(app=self.app)
         x._connection = Mock()
         x._connection = Mock()
@@ -117,11 +81,6 @@ class test_CouchBaseBackend(AppCase):
         x._connection.delete.assert_called_once_with('1f3fab')
         x._connection.delete.assert_called_once_with('1f3fab')
 
 
     def test_config_params(self):
     def test_config_params(self):
-        """
-        Test config params are correct.
-
-        app.conf.couchbase_backend_settings is properly set.
-        """
         self.app.conf.couchbase_backend_settings = {
         self.app.conf.couchbase_backend_settings = {
             'bucket': 'mycoolbucket',
             'bucket': 'mycoolbucket',
             'host': ['here.host.com', 'there.host.com'],
             'host': ['here.host.com', 'there.host.com'],
@@ -137,14 +96,12 @@ class test_CouchBaseBackend(AppCase):
         self.assertEqual(x.port, 1234)
         self.assertEqual(x.port, 1234)
 
 
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
-        """Test that a CouchBaseBackend is loaded from the couchbase url."""
         from celery.backends.couchbase import CouchBaseBackend
         from celery.backends.couchbase import CouchBaseBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         self.assertIs(backend, CouchBaseBackend)
         self.assertIs(backend, CouchBaseBackend)
         self.assertEqual(url_, url)
         self.assertEqual(url_, url)
 
 
     def test_backend_params_by_url(self):
     def test_backend_params_by_url(self):
-        """Test config params are correct from config url."""
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
         with self.Celery(backend=url) as app:
         with self.Celery(backend=url) as app:
             x = app.backend
             x = app.backend
@@ -155,13 +112,6 @@ class test_CouchBaseBackend(AppCase):
             self.assertEqual(x.port, 123)
             self.assertEqual(x.port, 123)
 
 
     def test_correct_key_types(self):
     def test_correct_key_types(self):
-        """
-        Test that the key is the correct type for the couchbase python API.
-
-        We check that get_key_for_task, get_key_for_chord, and
-        get_key_for_group always returns a python string. Need to use str_t
-        for cross Python reasons.
-        """
         keys = [
         keys = [
             self.backend.get_key_for_task('task_id', bytes('key')),
             self.backend.get_key_for_task('task_id', bytes('key')),
             self.backend.get_key_for_chord('group_id', bytes('key')),
             self.backend.get_key_for_chord('group_id', bytes('key')),

+ 2 - 3
celery/tests/backends/test_couchdb.py

@@ -5,7 +5,7 @@ from celery.backends.couchdb import CouchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery import backends
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Mock, SkipTest, patch, sentinel,
+    AppCase, Mock, patch, sentinel, skip_unless_module,
 )
 )
 
 
 try:
 try:
@@ -16,11 +16,10 @@ except ImportError:
 COUCHDB_CONTAINER = 'celery_container'
 COUCHDB_CONTAINER = 'celery_container'
 
 
 
 
+@skip_unless_module('pycouchdb')
 class test_CouchBackend(AppCase):
 class test_CouchBackend(AppCase):
 
 
     def setup(self):
     def setup(self):
-        if pycouchdb is None:
-            raise SkipTest('pycouchdb is not installed.')
         self.backend = CouchBackend(app=self.app)
         self.backend = CouchBackend(app=self.app)
 
 
     def test_init_no_pycouchdb(self):
     def test_init_no_pycouchdb(self):

+ 6 - 13
celery/tests/backends/test_database.py

@@ -11,11 +11,11 @@ from celery.utils import uuid
 from celery.tests.case import (
 from celery.tests.case import (
     AppCase,
     AppCase,
     Mock,
     Mock,
-    SkipTest,
     depends_on_current_app,
     depends_on_current_app,
     patch,
     patch,
     skip_if_pypy,
     skip_if_pypy,
     skip_if_jython,
     skip_if_jython,
+    skip_unless_module,
 )
 )
 
 
 try:
 try:
@@ -38,12 +38,9 @@ class SomeClass(object):
         self.data = data
         self.data = data
 
 
 
 
+@skip_unless_module('sqlalchemy')
 class test_session_cleanup(AppCase):
 class test_session_cleanup(AppCase):
 
 
-    def setup(self):
-        if session_cleanup is None:
-            raise SkipTest('slqlalchemy not installed')
-
     def test_context(self):
     def test_context(self):
         session = Mock(name='session')
         session = Mock(name='session')
         with session_cleanup(session):
         with session_cleanup(session):
@@ -59,13 +56,12 @@ class test_session_cleanup(AppCase):
         session.close.assert_called_with()
         session.close.assert_called_with()
 
 
 
 
+@skip_unless_module('sqlalchemy')
+@skip_if_pypy()
+@skip_if_jython()
 class test_DatabaseBackend(AppCase):
 class test_DatabaseBackend(AppCase):
 
 
-    @skip_if_pypy
-    @skip_if_jython
     def setup(self):
     def setup(self):
-        if DatabaseBackend is None:
-            raise SkipTest('sqlalchemy not installed')
         self.uri = 'sqlite:///test.db'
         self.uri = 'sqlite:///test.db'
         self.app.conf.result_serializer = 'pickle'
         self.app.conf.result_serializer = 'pickle'
 
 
@@ -218,12 +214,9 @@ class test_DatabaseBackend(AppCase):
         self.assertIn('foo', repr(TaskSet('foo', None)))
         self.assertIn('foo', repr(TaskSet('foo', None)))
 
 
 
 
+@skip_unless_module('sqlalchemy')
 class test_SessionManager(AppCase):
 class test_SessionManager(AppCase):
 
 
-    def setup(self):
-        if SessionManager is None:
-            raise SkipTest('sqlalchemy not installed')
-
     def test_after_fork(self):
     def test_after_fork(self):
         s = SessionManager()
         s = SessionManager()
         self.assertFalse(s.forked)
         self.assertFalse(s.forked)

+ 2 - 8
celery/tests/backends/test_elasticsearch.py

@@ -5,19 +5,13 @@ from celery.backends import elasticsearch as module
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
-from celery.tests.case import AppCase, Mock, SkipTest, sentinel
-
-try:
-    import elasticsearch
-except ImportError:
-    elasticsearch = None
+from celery.tests.case import AppCase, Mock, sentinel, skip_unless_module
 
 
 
 
+@skip_unless_module('elasticsearch')
 class test_ElasticsearchBackend(AppCase):
 class test_ElasticsearchBackend(AppCase):
 
 
     def setup(self):
     def setup(self):
-        if elasticsearch is None:
-            raise SkipTest('elasticsearch is not installed.')
         self.backend = ElasticsearchBackend(app=self.app)
         self.backend = ElasticsearchBackend(app=self.app)
 
 
     def test_init_no_elasticsearch(self):
     def test_init_no_elasticsearch(self):

+ 2 - 4
celery/tests/backends/test_filesystem.py

@@ -3,7 +3,6 @@ from __future__ import absolute_import, unicode_literals
 
 
 import os
 import os
 import shutil
 import shutil
-import sys
 import tempfile
 import tempfile
 
 
 from celery import states
 from celery import states
@@ -11,14 +10,13 @@ from celery.backends.filesystem import FilesystemBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.utils import uuid
 from celery.utils import uuid
 
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_if_win32
 
 
 
 
+@skip_if_win32()
 class test_FilesystemBackend(AppCase):
 class test_FilesystemBackend(AppCase):
 
 
     def setup(self):
     def setup(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
         self.directory = tempfile.mkdtemp()
         self.directory = tempfile.mkdtemp()
         self.url = 'file://' + self.directory
         self.url = 'file://' + self.directory
         self.path = self.directory.encode('ascii')
         self.path = self.directory.encode('ascii')

+ 7 - 13
celery/tests/backends/test_mongodb.py

@@ -9,13 +9,12 @@ from kombu.exceptions import EncodeError
 from celery import uuid
 from celery import uuid
 from celery import states
 from celery import states
 from celery.backends import mongodb as module
 from celery.backends import mongodb as module
-from celery.backends.mongodb import (
-    InvalidDocument, MongoBackend, pymongo,
-)
+from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, SkipTest, ANY,
-    depends_on_current_app, disable_stdouts, patch, sentinel,
+    AppCase, MagicMock, Mock, ANY,
+    depends_on_current_app, override_stdouts, patch, sentinel,
+    skip_unless_module,
 )
 )
 
 
 COLLECTION = 'taskmeta_celery'
 COLLECTION = 'taskmeta_celery'
@@ -29,6 +28,7 @@ MONGODB_COLLECTION = 'collection1'
 MONGODB_GROUP_COLLECTION = 'group_collection1'
 MONGODB_GROUP_COLLECTION = 'group_collection1'
 
 
 
 
+@skip_unless_module('pymongo')
 class test_MongoBackend(AppCase):
 class test_MongoBackend(AppCase):
 
 
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
@@ -43,9 +43,6 @@ class test_MongoBackend(AppCase):
     )
     )
 
 
     def setup(self):
     def setup(self):
-        if pymongo is None:
-            raise SkipTest('pymongo is not installed.')
-
         R = self._reset = {}
         R = self._reset = {}
         R['encode'], MongoBackend.encode = MongoBackend.encode, Mock()
         R['encode'], MongoBackend.encode = MongoBackend.encode, Mock()
         R['decode'], MongoBackend.decode = MongoBackend.decode, Mock()
         R['decode'], MongoBackend.decode = MongoBackend.decode, Mock()
@@ -410,7 +407,7 @@ class test_MongoBackend(AppCase):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
         self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_regression_worker_startup_info(self):
     def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
         self.app.conf.result_backend = (
             'mongodb://user:password@host0.com:43437,host1.com:43437'
             'mongodb://user:password@host0.com:43437,host1.com:43437'
@@ -421,12 +418,9 @@ class test_MongoBackend(AppCase):
         self.assertTrue(worker.startup_info())
         self.assertTrue(worker.startup_info())
 
 
 
 
+@skip_unless_module('pymongo')
 class test_MongoBackend_no_mock(AppCase):
 class test_MongoBackend_no_mock(AppCase):
 
 
-    def setup(self):
-        if pymongo is None:
-            raise SkipTest('pymongo is not installed.')
-
     def test_encode_decode(self):
     def test_encode_decode(self):
         backend = MongoBackend(app=self.app)
         backend = MongoBackend(app=self.app)
         data = {'foo': 1}
         data = {'foo': 1}

+ 6 - 8
celery/tests/backends/test_redis.py

@@ -13,8 +13,8 @@ from celery.datastructures import AttributeDict
 from celery.exceptions import ChordError, ImproperlyConfigured
 from celery.exceptions import ChordError, ImproperlyConfigured
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    ANY, AppCase, ContextMock, Mock, MockCallbacks, SkipTest,
-    call, depends_on_current_app, patch,
+    ANY, AppCase, ContextMock, Mock, MockCallbacks,
+    call, depends_on_current_app, patch, skip_unless_module,
 )
 )
 
 
 
 
@@ -142,13 +142,11 @@ class test_RedisBackend(AppCase):
         self.b = self.Backend(app=self.app)
         self.b = self.Backend(app=self.app)
 
 
     @depends_on_current_app
     @depends_on_current_app
+    @skip_unless_module('redis')
     def test_reduce(self):
     def test_reduce(self):
-        try:
-            from celery.backends.redis import RedisBackend
-            x = RedisBackend(app=self.app)
-            self.assertTrue(loads(dumps(x)))
-        except ImportError:
-            raise SkipTest('redis not installed')
+        from celery.backends.redis import RedisBackend
+        x = RedisBackend(app=self.app)
+        self.assertTrue(loads(dumps(x)))
 
 
     def test_no_redis(self):
     def test_no_redis(self):
         self.Backend.redis = None
         self.Backend.redis = None

+ 3 - 4
celery/tests/backends/test_riak.py

@@ -3,21 +3,20 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 from celery.backends import riak as module
 from celery.backends import riak as module
-from celery.backends.riak import RiakBackend, riak
+from celery.backends.riak import RiakBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, SkipTest, patch, sentinel,
+    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
 )
 )
 
 
 
 
 RIAK_BUCKET = 'riak_bucket'
 RIAK_BUCKET = 'riak_bucket'
 
 
 
 
+@skip_unless_module('riak')
 class test_RiakBackend(AppCase):
 class test_RiakBackend(AppCase):
 
 
     def setup(self):
     def setup(self):
-        if riak is None:
-            raise SkipTest('riak is not installed.')
         self.app.conf.result_backend = 'riak://'
         self.app.conf.result_backend = 'riak://'
 
 
     @property
     @property

+ 2 - 1
celery/tests/bin/test_amqp.py

@@ -7,8 +7,9 @@ from celery.bin.amqp import (
     amqp,
     amqp,
     main,
     main,
 )
 )
+from celery.five import WhateverIO
 
 
-from celery.tests.case import AppCase, Mock, WhateverIO, patch
+from celery.tests.case import AppCase, Mock, patch
 
 
 
 
 class test_AMQShell(AppCase):
 class test_AMQShell(AppCase):

+ 3 - 2
celery/tests/bin/test_celery.py

@@ -7,7 +7,6 @@ from datetime import datetime
 from kombu.utils.json import dumps
 from kombu.utils.json import dumps
 
 
 from celery import __main__
 from celery import __main__
-from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 from celery.bin.base import Error
 from celery.bin.base import Error
 from celery.bin import celery as mod
 from celery.bin import celery as mod
 from celery.bin.celery import (
 from celery.bin.celery import (
@@ -29,8 +28,10 @@ from celery.bin.celery import (
     _RemoteControl,
     _RemoteControl,
     command,
     command,
 )
 )
+from celery.five import WhateverIO
+from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 
 
-from celery.tests.case import AppCase, Mock, WhateverIO, patch
+from celery.tests.case import AppCase, Mock, patch
 
 
 
 
 class test__main__(AppCase):
 class test__main__(AppCase):

+ 2 - 1
celery/tests/bin/test_celeryevdump.py

@@ -7,8 +7,9 @@ from celery.events.dumper import (
     Dumper,
     Dumper,
     evdump,
     evdump,
 )
 )
+from celery.five import WhateverIO
 
 
-from celery.tests.case import AppCase, Mock, WhateverIO, patch
+from celery.tests.case import AppCase, Mock, patch
 
 
 
 
 class test_Dumper(AppCase):
 class test_Dumper(AppCase):

+ 2 - 6
celery/tests/bin/test_events.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import, unicode_literals
 
 
 from celery.bin import events
 from celery.bin import events
 
 
-from celery.tests.case import AppCase, SkipTest, patch, _old_patch
+from celery.tests.case import AppCase, patch, _old_patch, skip_unless_module
 
 
 
 
 class MockCommand(object):
 class MockCommand(object):
@@ -29,12 +29,8 @@ class test_events(AppCase):
         self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
         self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
         self.assertIn('celery events:dump', proctitle.last[0])
         self.assertIn('celery events:dump', proctitle.last[0])
 
 
+    @skip_unless_module('curses')
     def test_run_top(self):
     def test_run_top(self):
-        try:
-            import curses  # noqa
-        except (ImportError, OSError):
-            raise SkipTest('curses monitor requires curses')
-
         @_old_patch('celery.events.cursesmon', 'evtop',
         @_old_patch('celery.events.cursesmon', 'evtop',
                     lambda **kw: 'me top, you?')
                     lambda **kw: 'me top, you?')
         @_old_patch('celery.bin.events', 'set_process_title', proctitle)
         @_old_patch('celery.bin.events', 'set_process_title', proctitle)

+ 3 - 3
celery/tests/bin/test_multi.py

@@ -15,8 +15,9 @@ from celery.bin.multi import (
     multi_args,
     multi_args,
     __doc__ as doc,
     __doc__ as doc,
 )
 )
+from celery.five import WhateverIO
 
 
-from celery.tests.case import AppCase, Mock, WhateverIO, SkipTest, patch
+from celery.tests.case import AppCase, Mock, patch, skip_unless_symbol
 
 
 
 
 class test_functions(AppCase):
 class test_functions(AppCase):
@@ -264,9 +265,8 @@ class test_MultiTool(AppCase):
 
 
         )
         )
 
 
+    @skip_unless_symbol('signal.SIGKILL')
     def test_kill(self):
     def test_kill(self):
-        if not hasattr(signal, 'SIGKILL'):
-            raise SkipTest('SIGKILL not supported by this platform')
         self.t.getpids = Mock()
         self.t.getpids = Mock()
         self.t.getpids.return_value = [
         self.t.getpids.return_value = [
             ('a', None, 10),
             ('a', None, 10),

+ 35 - 61
celery/tests/bin/test_worker.py

@@ -21,18 +21,19 @@ from celery.worker import state
 from celery.tests.case import (
 from celery.tests.case import (
     AppCase,
     AppCase,
     Mock,
     Mock,
-    SkipTest,
-    disable_stdouts,
+    override_stdouts,
     patch,
     patch,
-    skip_if_pypy,
     skip_if_jython,
     skip_if_jython,
+    skip_if_pypy,
+    skip_if_win32,
+    skip_unless_module,
+    skip_unless_symbol,
 )
 )
 
 
 
 
 class WorkerAppCase(AppCase):
 class WorkerAppCase(AppCase):
 
 
-    def tearDown(self):
-        super(WorkerAppCase, self).tearDown()
+    def teardown(self):
         trace.reset_worker_optimizations()
         trace.reset_worker_optimizations()
 
 
 
 
@@ -46,13 +47,13 @@ class Worker(cd.Worker):
 class test_Worker(WorkerAppCase):
 class test_Worker(WorkerAppCase):
     Worker = Worker
     Worker = Worker
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_queues_string(self):
     def test_queues_string(self):
         w = self.app.Worker()
         w = self.app.Worker()
         w.setup_queues('foo,bar,baz')
         w.setup_queues('foo,bar,baz')
         self.assertIn('foo', self.app.amqp.queues)
         self.assertIn('foo', self.app.amqp.queues)
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_cpu_count(self):
     def test_cpu_count(self):
         with patch('celery.worker.cpu_count') as cpu_count:
         with patch('celery.worker.cpu_count') as cpu_count:
             cpu_count.side_effect = NotImplementedError()
             cpu_count.side_effect = NotImplementedError()
@@ -61,7 +62,7 @@ class test_Worker(WorkerAppCase):
         w = self.app.Worker(concurrency=5)
         w = self.app.Worker(concurrency=5)
         self.assertEqual(w.concurrency, 5)
         self.assertEqual(w.concurrency, 5)
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_windows_B_option(self):
     def test_windows_B_option(self):
         self.app.IS_WINDOWS = True
         self.app.IS_WINDOWS = True
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
@@ -93,7 +94,7 @@ class test_Worker(WorkerAppCase):
                 x.maybe_detach(['--detach'])
                 x.maybe_detach(['--detach'])
             self.assertTrue(detached.called)
             self.assertTrue(detached.called)
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_invalid_loglevel_gives_error(self):
     def test_invalid_loglevel_gives_error(self):
         x = worker(app=self.app)
         x = worker(app=self.app)
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
@@ -117,12 +118,12 @@ class test_Worker(WorkerAppCase):
         worker.loglevel = logging.INFO
         worker.loglevel = logging.INFO
         self.assertTrue(worker.extra_info())
         self.assertTrue(worker.extra_info())
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_loglevel_string(self):
     def test_loglevel_string(self):
         worker = self.Worker(app=self.app, loglevel='INFO')
         worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
         self.assertEqual(worker.loglevel, logging.INFO)
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_run_worker(self):
     def test_run_worker(self):
         handlers = {}
         handlers = {}
 
 
@@ -150,7 +151,7 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             platforms.signals = p
             platforms.signals = p
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_startup_info(self):
     def test_startup_info(self):
         worker = self.Worker(app=self.app)
         worker = self.Worker(app=self.app)
         worker.on_start()
         worker.on_start()
@@ -188,18 +189,18 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             cd.ARTLINES = prev
             cd.ARTLINES = prev
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_run(self):
     def test_run(self):
         self.Worker(app=self.app).on_start()
         self.Worker(app=self.app).on_start()
         self.Worker(app=self.app, purge=True).on_start()
         self.Worker(app=self.app, purge=True).on_start()
         worker = self.Worker(app=self.app)
         worker = self.Worker(app=self.app)
         worker.on_start()
         worker.on_start()
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_purge_messages(self):
     def test_purge_messages(self):
         self.Worker(app=self.app).purge_messages()
         self.Worker(app=self.app).purge_messages()
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_init_queues(self):
     def test_init_queues(self):
         app = self.app
         app = self.app
         c = app.conf
         c = app.conf
@@ -230,7 +231,7 @@ class test_Worker(WorkerAppCase):
             app.amqp.queues['image'],
             app.amqp.queues['image'],
         )
         )
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_autoscale_argument(self):
     def test_autoscale_argument(self):
         worker1 = self.Worker(app=self.app, autoscale='10,3')
         worker1 = self.Worker(app=self.app, autoscale='10,3')
         self.assertListEqual(worker1.autoscale, [10, 3])
         self.assertListEqual(worker1.autoscale, [10, 3])
@@ -246,20 +247,17 @@ class test_Worker(WorkerAppCase):
         self.assertListEqual(worker2.include, ['os', 'sys'])
         self.assertListEqual(worker2.include, ['os', 'sys'])
         self.Worker(app=self.app, include=['os', 'sys'])
         self.Worker(app=self.app, include=['os', 'sys'])
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_unknown_loglevel(self):
     def test_unknown_loglevel(self):
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(loglevel='ALIEN')
             worker(app=self.app).run(loglevel='ALIEN')
         worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
 
-    @disable_stdouts
+    @skip_if_win32
+    @override_stdouts
     @patch('os._exit')
     @patch('os._exit')
     def test_warns_if_running_as_privileged_user(self, _exit):
     def test_warns_if_running_as_privileged_user(self, _exit):
-        app = self.app
-        if app.IS_WINDOWS:
-            raise SkipTest('Not applicable on Windows')
-
         with patch('os.getuid') as getuid:
         with patch('os.getuid') as getuid:
             getuid.return_value = 0
             getuid.return_value = 0
             self.app.conf.accept_content = ['pickle']
             self.app.conf.accept_content = ['pickle']
@@ -283,13 +281,13 @@ class test_Worker(WorkerAppCase):
                 worker = self.Worker(app=self.app)
                 worker = self.Worker(app=self.app)
                 worker.on_start()
                 worker.on_start()
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_redirect_stdouts(self):
     def test_redirect_stdouts(self):
         self.Worker(app=self.app, redirect_stdouts=False)
         self.Worker(app=self.app, redirect_stdouts=False)
         with self.assertRaises(AttributeError):
         with self.assertRaises(AttributeError):
             sys.stdout.logger
             sys.stdout.logger
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_on_start_custom_logging(self):
     def test_on_start_custom_logging(self):
         self.app.log.redirect_stdouts = Mock()
         self.app.log.redirect_stdouts = Mock()
         worker = self.Worker(app=self.app, redirect_stoutds=True)
         worker = self.Worker(app=self.app, redirect_stoutds=True)
@@ -308,7 +306,7 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             self.app.log.setup = prev
             self.app.log.setup = prev
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_startup_info_pool_is_str(self):
     def test_startup_info_pool_is_str(self):
         worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker.pool_cls = 'foo'
         worker.pool_cls = 'foo'
@@ -331,7 +329,7 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
             signals.setup_logging.disconnect(on_logging_setup)
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_platform_tweaks_osx(self):
     def test_platform_tweaks_osx(self):
 
 
         class OSXWorker(Worker):
         class OSXWorker(Worker):
@@ -359,7 +357,7 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             cd.install_HUP_not_supported_handler = prev
             cd.install_HUP_not_supported_handler = prev
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_general_platform_tweaks(self):
     def test_general_platform_tweaks(self):
 
 
         restart_worker_handler_installed = [False]
         restart_worker_handler_installed = [False]
@@ -380,7 +378,7 @@ class test_Worker(WorkerAppCase):
         finally:
         finally:
             cd.install_worker_restart_handler = prev
             cd.install_worker_restart_handler = prev
 
 
-    @disable_stdouts
+    @override_stdouts
     def test_on_consumer_ready(self):
     def test_on_consumer_ready(self):
         worker_ready_sent = [False]
         worker_ready_sent = [False]
 
 
@@ -392,17 +390,14 @@ class test_Worker(WorkerAppCase):
         self.assertTrue(worker_ready_sent[0])
         self.assertTrue(worker_ready_sent[0])
 
 
 
 
+@override_stdouts
 class test_funs(WorkerAppCase):
 class test_funs(WorkerAppCase):
 
 
     def test_active_thread_count(self):
     def test_active_thread_count(self):
         self.assertTrue(cd.active_thread_count())
         self.assertTrue(cd.active_thread_count())
 
 
-    @disable_stdouts
+    @skip_unless_module('setproctitle')
     def test_set_process_status(self):
     def test_set_process_status(self):
-        try:
-            __import__('setproctitle')
-        except ImportError:
-            raise SkipTest('setproctitle not installed')
         worker = Worker(app=self.app, hostname='xyzza')
         worker = Worker(app=self.app, hostname='xyzza')
         prev1, sys.argv = sys.argv, ['Arg0']
         prev1, sys.argv = sys.argv, ['Arg0']
         try:
         try:
@@ -422,7 +417,6 @@ class test_funs(WorkerAppCase):
         finally:
         finally:
             sys.argv = prev1
             sys.argv = prev1
 
 
-    @disable_stdouts
     def test_parse_options(self):
     def test_parse_options(self):
         cmd = worker()
         cmd = worker()
         cmd.app = self.app
         cmd.app = self.app
@@ -431,7 +425,6 @@ class test_funs(WorkerAppCase):
         self.assertEqual(opts.concurrency, 512)
         self.assertEqual(opts.concurrency, 512)
         self.assertEqual(opts.heartbeat_interval, 10)
         self.assertEqual(opts.heartbeat_interval, 10)
 
 
-    @disable_stdouts
     def test_main(self):
     def test_main(self):
         p, cd.Worker = cd.Worker, Worker
         p, cd.Worker = cd.Worker, Worker
         s, sys.argv = sys.argv, ['worker', '--discard']
         s, sys.argv = sys.argv, ['worker', '--discard']
@@ -442,6 +435,7 @@ class test_funs(WorkerAppCase):
             sys.argv = s
             sys.argv = s
 
 
 
 
+@override_stdouts
 class test_signal_handlers(WorkerAppCase):
 class test_signal_handlers(WorkerAppCase):
 
 
     class _Worker(object):
     class _Worker(object):
@@ -468,7 +462,6 @@ class test_signal_handlers(WorkerAppCase):
         finally:
         finally:
             platforms.signals = p
             platforms.signals = p
 
 
-    @disable_stdouts
     def test_worker_int_handler(self):
     def test_worker_int_handler(self):
         worker = self._Worker()
         worker = self._Worker()
         handlers = self.psig(cd.install_worker_int_handler, worker)
         handlers = self.psig(cd.install_worker_int_handler, worker)
@@ -511,12 +504,8 @@ class test_signal_handlers(WorkerAppCase):
             with self.assertRaises(WorkerTerminate):
             with self.assertRaises(WorkerTerminate):
                 next_handlers['SIGINT']('SIGINT', object())
                 next_handlers['SIGINT']('SIGINT', object())
 
 
-    @disable_stdouts
+    @skip_unless_module('multiprocessing')
     def test_worker_int_handler_only_stop_MainProcess(self):
     def test_worker_int_handler_only_stop_MainProcess(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('only relevant for multiprocessing')
         process = current_process()
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         name, process.name = process.name, 'OtherProcess'
         with patch('celery.apps.worker.active_thread_count') as c:
         with patch('celery.apps.worker.active_thread_count') as c:
@@ -541,18 +530,13 @@ class test_signal_handlers(WorkerAppCase):
                 process.name = name
                 process.name = name
                 state.should_stop = None
                 state.should_stop = None
 
 
-    @disable_stdouts
     def test_install_HUP_not_supported_handler(self):
     def test_install_HUP_not_supported_handler(self):
         worker = self._Worker()
         worker = self._Worker()
         handlers = self.psig(cd.install_HUP_not_supported_handler, worker)
         handlers = self.psig(cd.install_HUP_not_supported_handler, worker)
         handlers['SIGHUP']('SIGHUP', object())
         handlers['SIGHUP']('SIGHUP', object())
 
 
-    @disable_stdouts
+    @skip_unless_module('multiprocessing')
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('only relevant for multiprocessing')
         process = current_process()
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         name, process.name = process.name, 'OtherProcess'
         try:
         try:
@@ -579,7 +563,6 @@ class test_signal_handlers(WorkerAppCase):
         finally:
         finally:
             process.name = name
             process.name = name
 
 
-    @disable_stdouts
     def test_worker_term_handler_when_threads(self):
     def test_worker_term_handler_when_threads(self):
         with patch('celery.apps.worker.active_thread_count') as c:
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 3
             c.return_value = 3
@@ -591,7 +574,6 @@ class test_signal_handlers(WorkerAppCase):
             finally:
             finally:
                 state.should_stop = None
                 state.should_stop = None
 
 
-    @disable_stdouts
     def test_worker_term_handler_when_single_thread(self):
     def test_worker_term_handler_when_single_thread(self):
         with patch('celery.apps.worker.active_thread_count') as c:
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 1
             c.return_value = 1
@@ -604,19 +586,15 @@ class test_signal_handlers(WorkerAppCase):
                 state.should_stop = None
                 state.should_stop = None
 
 
     @patch('sys.__stderr__')
     @patch('sys.__stderr__')
-    @skip_if_pypy
-    @skip_if_jython
+    @skip_if_pypy()
+    @skip_if_jython()
     def test_worker_cry_handler(self, stderr):
     def test_worker_cry_handler(self, stderr):
         handlers = self.psig(cd.install_cry_handler)
         handlers = self.psig(cd.install_cry_handler)
         self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
         self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
         self.assertTrue(stderr.write.called)
         self.assertTrue(stderr.write.called)
 
 
-    @disable_stdouts
+    @skip_unless_module('multiprocessing')
     def test_worker_term_handler_only_stop_MainProcess(self):
     def test_worker_term_handler_only_stop_MainProcess(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('only relevant for multiprocessing')
         process = current_process()
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         name, process.name = process.name, 'OtherProcess'
         try:
         try:
@@ -636,13 +614,11 @@ class test_signal_handlers(WorkerAppCase):
             process.name = name
             process.name = name
             state.should_stop = None
             state.should_stop = None
 
 
-    @disable_stdouts
+    @skip_unless_symbol('os.execv')
     @patch('celery.platforms.close_open_fds')
     @patch('celery.platforms.close_open_fds')
     @patch('atexit.register')
     @patch('atexit.register')
     @patch('os.close')
     @patch('os.close')
     def test_worker_restart_handler(self, _close, register, close_open):
     def test_worker_restart_handler(self, _close, register, close_open):
-        if getattr(os, 'execv', None) is None:
-            raise SkipTest('platform does not have excv')
         argv = []
         argv = []
 
 
         def _execv(*args):
         def _execv(*args):
@@ -662,7 +638,6 @@ class test_signal_handlers(WorkerAppCase):
             os.execv = execv
             os.execv = execv
             state.should_stop = None
             state.should_stop = None
 
 
-    @disable_stdouts
     def test_worker_term_hard_handler_when_threaded(self):
     def test_worker_term_hard_handler_when_threaded(self):
         with patch('celery.apps.worker.active_thread_count') as c:
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 3
             c.return_value = 3
@@ -674,7 +649,6 @@ class test_signal_handlers(WorkerAppCase):
             finally:
             finally:
                 state.should_terminate = None
                 state.should_terminate = None
 
 
-    @disable_stdouts
     def test_worker_term_hard_handler_when_single_threaded(self):
     def test_worker_term_hard_handler_when_single_threaded(self):
         with patch('celery.apps.worker.active_thread_count') as c:
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 1
             c.return_value = 1

+ 33 - 682
celery/tests/case.py

@@ -1,37 +1,18 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-try:
-    import unittest  # noqa
-    unittest.skip
-    from unittest.util import safe_repr, unorderable_list_difference
-except AttributeError:
-    import unittest2 as unittest  # noqa
-    from unittest2.util import safe_repr, unorderable_list_difference  # noqa
-
 import importlib
 import importlib
 import inspect
 import inspect
 import logging
 import logging
 import numbers
 import numbers
 import os
 import os
-import platform
-import re
 import sys
 import sys
 import threading
 import threading
-import time
-import types
-import warnings
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
 from copy import deepcopy
 from copy import deepcopy
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from functools import partial, wraps
 from functools import partial, wraps
-from types import ModuleType
 
 
-try:
-    from unittest import mock
-except ImportError:
-    import mock  # noqa
-from nose import SkipTest
 from kombu import Queue
 from kombu import Queue
 from kombu.utils import symbol_by_name
 from kombu.utils import symbol_by_name
 
 
@@ -39,31 +20,15 @@ from celery import Celery
 from celery.app import current_app
 from celery.app import current_app
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
-from celery.five import (
-    WhateverIO, builtins, items, reraise,
-    string_t, values, open_fqdn, module_name_t,
-)
-from celery.utils.functional import noop
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
 
 
-__all__ = [
-    'Case', 'AppCase', 'Mock', 'MagicMock', 'ANY', 'TaskMessage',
-    'patch', 'call', 'sentinel', 'skip_unless_module',
-    'wrap_logger', 'with_environ', 'sleepdeprived',
-    'skip_if_environ', 'todo', 'skip', 'skip_if',
-    'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
-    'replace_module_value', 'sys_platform', 'reset_modules',
-    'patch_modules', 'mock_context', 'mock_open',
-    'assert_signal_called', 'skip_if_pypy',
-    'skip_if_jython', 'task_message_from_sig', 'restore_logging',
-]
-patch = mock.patch
-call = mock.call
-sentinel = mock.sentinel
-MagicMock = mock.MagicMock
-ANY = mock.ANY
+from ._case import *  # noqa
+from ._case import __all__ as _case_all, Case as _Case, decorator
 
 
-PY3 = sys.version_info[0] == 3
+__all__ = _case_all + [
+    'AppCase', 'TaskMessage', 'TaskMessage1',
+    'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
+]
 
 
 CASE_REDEFINES_SETUP = """\
 CASE_REDEFINES_SETUP = """\
 {name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
 {name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
@@ -110,6 +75,11 @@ CELERY_TEST_CONFIG = {
 }
 }
 
 
 
 
+class Case(_Case):
+    DeprecationWarning = CDeprecationWarning
+    PendingDeprecationWarning = CPendingDeprecationWarning
+
+
 class Trap(object):
 class Trap(object):
 
 
     def __getattr__(self, name):
     def __getattr__(self, name):
@@ -133,299 +103,10 @@ def UnitApp(name=None, set_as_current=False, log=UnitLogging,
     return app
     return app
 
 
 
 
-class Mock(mock.Mock):
-
-    def __init__(self, *args, **kwargs):
-        attrs = kwargs.pop('attrs', None) or {}
-        super(Mock, self).__init__(*args, **kwargs)
-        for attr_name, attr_value in items(attrs):
-            setattr(self, attr_name, attr_value)
-
-
-class _ContextMock(Mock):
-    """Dummy class implementing __enter__ and __exit__
-    as the :keyword:`with` statement requires these to be implemented
-    in the class, not just the instance."""
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *exc_info):
-        pass
-
-
-def ContextMock(*args, **kwargs):
-    obj = _ContextMock(*args, **kwargs)
-    obj.attach_mock(_ContextMock(), '__enter__')
-    obj.attach_mock(_ContextMock(), '__exit__')
-    obj.__enter__.return_value = obj
-    # if __exit__ return a value the exception is ignored,
-    # so it must return None here.
-    obj.__exit__.return_value = None
-    return obj
-
-
-def _bind(f, o):
-    @wraps(f)
-    def bound_meth(*fargs, **fkwargs):
-        return f(o, *fargs, **fkwargs)
-    return bound_meth
-
-
-if PY3:  # pragma: no cover
-    def _get_class_fun(meth):
-        return meth
-else:
-    def _get_class_fun(meth):
-        return meth.__func__
-
-
-class MockCallbacks(object):
-
-    def __new__(cls, *args, **kwargs):
-        r = Mock(name=cls.__name__)
-        _get_class_fun(cls.__init__)(r, *args, **kwargs)
-        for key, value in items(vars(cls)):
-            if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
-                if inspect.ismethod(value) or inspect.isfunction(value):
-                    r.__getattr__(key).side_effect = _bind(value, r)
-                else:
-                    r.__setattr__(key, value)
-        return r
-
-
-def skip_unless_module(module):
-
-    def _inner(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            try:
-                importlib.import_module(module)
-            except ImportError:
-                raise SkipTest('Does not have %s' % (module,))
-
-            return fun(*args, **kwargs)
-
-        return __inner
-    return _inner
-
-
-# -- adds assertWarns from recent unittest2, not in Python 2.7.
-
-class _AssertRaisesBaseContext(object):
-
-    def __init__(self, expected, test_case, callable_obj=None,
-                 expected_regex=None):
-        self.expected = expected
-        self.failureException = test_case.failureException
-        self.obj_name = None
-        if isinstance(expected_regex, string_t):
-            expected_regex = re.compile(expected_regex)
-        self.expected_regex = expected_regex
-
-
-def _is_magic_module(m):
-    # some libraries create custom module types that are lazily
-    # lodaded, e.g. Django installs some modules in sys.modules that
-    # will load _tkinter and other shit when touched.
-
-    # pyflakes refuses to accept 'noqa' for this isinstance.
-    cls, modtype = type(m), types.ModuleType
-    try:
-        variables = vars(cls)
-    except TypeError:
-        return True
-    else:
-        return (cls is not modtype and (
-            '__getattr__' in variables or
-            '__getattribute__' in variables))
-
-
-class _AssertWarnsContext(_AssertRaisesBaseContext):
-    """A context manager used to implement TestCase.assertWarns* methods."""
-
-    def __enter__(self):
-        # The __warningregistry__'s need to be in a pristine state for tests
-        # to work properly.
-        warnings.resetwarnings()
-        for v in list(values(sys.modules)):
-            # do not evaluate Django moved modules and other lazily
-            # initialized modules.
-            if v and not _is_magic_module(v):
-                # use raw __getattribute__ to protect even better from
-                # lazily loaded modules
-                try:
-                    object.__getattribute__(v, '__warningregistry__')
-                except AttributeError:
-                    pass
-                else:
-                    object.__setattr__(v, '__warningregistry__', {})
-        self.warnings_manager = warnings.catch_warnings(record=True)
-        self.warnings = self.warnings_manager.__enter__()
-        warnings.simplefilter('always', self.expected)
-        return self
-
-    def __exit__(self, exc_type, exc_value, tb):
-        self.warnings_manager.__exit__(exc_type, exc_value, tb)
-        if exc_type is not None:
-            # let unexpected exceptions pass through
-            return
-        try:
-            exc_name = self.expected.__name__
-        except AttributeError:
-            exc_name = str(self.expected)
-        first_matching = None
-        for m in self.warnings:
-            w = m.message
-            if not isinstance(w, self.expected):
-                continue
-            if first_matching is None:
-                first_matching = w
-            if (self.expected_regex is not None and
-                    not self.expected_regex.search(str(w))):
-                continue
-            # store warning for later retrieval
-            self.warning = w
-            self.filename = m.filename
-            self.lineno = m.lineno
-            return
-        # Now we simply try to choose a helpful failure message
-        if first_matching is not None:
-            raise self.failureException(
-                '%r does not match %r' % (
-                    self.expected_regex.pattern, str(first_matching)))
-        if self.obj_name:
-            raise self.failureException(
-                '%s not triggered by %s' % (exc_name, self.obj_name))
-        else:
-            raise self.failureException('%s not triggered' % exc_name)
-
-
 def alive_threads():
 def alive_threads():
     return [thread for thread in threading.enumerate() if thread.is_alive()]
     return [thread for thread in threading.enumerate() if thread.is_alive()]
 
 
 
 
-class Case(unittest.TestCase):
-
-    def patch(self, *path, **options):
-        manager = patch('.'.join(path), **options)
-        patched = manager.start()
-        self.addCleanup(manager.stop)
-        return patched
-
-    def mock_modules(self, *mods):
-        modules = []
-        for mod in mods:
-            mod = mod.split('.')
-            modules.extend(reversed([
-                '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
-            ]))
-        modules = sorted(set(modules))
-        return self.wrap_context(mock_module(*modules))
-
-    def on_nth_call_do(self, mock, side_effect, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.side_effect = side_effect
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def on_nth_call_return(self, mock, retval, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.return_value = retval
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def mask_modules(self, *modules):
-        self.wrap_context(mask_modules(*modules))
-
-    def wrap_context(self, context):
-        ret = context.__enter__()
-        self.addCleanup(partial(context.__exit__, None, None, None))
-        return ret
-
-    def mock_environ(self, env_name, env_value):
-        return self.wrap_context(mock_environ(env_name, env_value))
-
-    def assertWarns(self, expected_warning):
-        return _AssertWarnsContext(expected_warning, self, None)
-
-    def assertWarnsRegex(self, expected_warning, expected_regex):
-        return _AssertWarnsContext(expected_warning, self,
-                                   None, expected_regex)
-
-    @contextmanager
-    def assertDeprecated(self):
-        with self.assertWarnsRegex(CDeprecationWarning,
-                                   r'scheduled for removal'):
-            yield
-
-    @contextmanager
-    def assertPendingDeprecation(self):
-        with self.assertWarnsRegex(CPendingDeprecationWarning,
-                                   r'scheduled for deprecation'):
-            yield
-
-    def assertDictContainsSubset(self, expected, actual, msg=None):
-        missing, mismatched = [], []
-
-        for key, value in items(expected):
-            if key not in actual:
-                missing.append(key)
-            elif value != actual[key]:
-                mismatched.append('%s, expected: %s, actual: %s' % (
-                    safe_repr(key), safe_repr(value),
-                    safe_repr(actual[key])))
-
-        if not (missing or mismatched):
-            return
-
-        standard_msg = ''
-        if missing:
-            standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
-
-        if mismatched:
-            if standard_msg:
-                standard_msg += '; '
-            standard_msg += 'Mismatched values: %s' % (
-                ','.join(mismatched))
-
-        self.fail(self._formatMessage(msg, standard_msg))
-
-    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
-        missing = unexpected = None
-        try:
-            expected = sorted(expected_seq)
-            actual = sorted(actual_seq)
-        except TypeError:
-            # Unsortable items (example: set(), complex(), ...)
-            expected = list(expected_seq)
-            actual = list(actual_seq)
-            missing, unexpected = unorderable_list_difference(
-                expected, actual)
-        else:
-            return self.assertSequenceEqual(expected, actual, msg=msg)
-
-        errors = []
-        if missing:
-            errors.append(
-                'Expected, but missing:\n    %s' % (safe_repr(missing),)
-            )
-        if unexpected:
-            errors.append(
-                'Unexpected, but present:\n    %s' % (safe_repr(unexpected),)
-            )
-        if errors:
-            standardMsg = '\n'.join(errors)
-            self.fail(self._formatMessage(msg, standardMsg))
-
-
 def depends_on_current_app(fun):
 def depends_on_current_app(fun):
     if inspect.isclass(fun):
     if inspect.isclass(fun):
         fun.contained = False
         fun.contained = False
@@ -443,11 +124,13 @@ class AppCase(Case):
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         super(AppCase, self).__init__(*args, **kwargs)
         super(AppCase, self).__init__(*args, **kwargs)
-        if self.__class__.__dict__.get('setUp'):
+        setUp = self.__class__.__dict__.get('setUp')
+        tearDown = self.__class__.__dict__.get('tearDown')
+        if setUp and not hasattr(setUp, '__wrapped__'):
             raise RuntimeError(
             raise RuntimeError(
                 CASE_REDEFINES_SETUP.format(name=qualname(self)),
                 CASE_REDEFINES_SETUP.format(name=qualname(self)),
             )
             )
-        if self.__class__.__dict__.get('tearDown'):
+        if tearDown and not hasattr(tearDown, '__wrapped__'):
             raise RuntimeError(
             raise RuntimeError(
                 CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
                 CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
             )
             )
@@ -552,6 +235,9 @@ class AppCase(Case):
         if root.handlers != self.__roothandlers:
         if root.handlers != self.__roothandlers:
             raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
             raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
 
 
+    def assert_signal_called(self, signal, **expected):
+        return assert_signal_called(signal, **expected)
+
     def setup(self):
     def setup(self):
         pass
         pass
 
 
@@ -559,324 +245,7 @@ class AppCase(Case):
         pass
         pass
 
 
 
 
-def get_handlers(logger):
-    return [
-        h for h in logger.handlers
-        if not isinstance(h, logging.NullHandler)
-    ]
-
-
-@contextmanager
-def wrap_logger(logger, loglevel=logging.ERROR):
-    old_handlers = get_handlers(logger)
-    sio = WhateverIO()
-    siohandler = logging.StreamHandler(sio)
-    logger.handlers = [siohandler]
-
-    try:
-        yield sio
-    finally:
-        logger.handlers = old_handlers
-
-
-@contextmanager
-def mock_environ(env_name, env_value):
-    sentinel = object()
-    prev_val = os.environ.get(env_name, sentinel)
-    os.environ[env_name] = env_value
-    try:
-        yield env_value
-    finally:
-        if prev_val is sentinel:
-            os.environ.pop(env_name, None)
-        else:
-            os.environ[env_name] = prev_val
-
-
-def with_environ(env_name, env_value):
-
-    def _envpatched(fun):
-
-        @wraps(fun)
-        def _patch_environ(*args, **kwargs):
-            with mock_environ(env_name, env_value):
-                return fun(*args, **kwargs)
-        return _patch_environ
-    return _envpatched
-
-
-def sleepdeprived(module=time):
-
-    def _sleepdeprived(fun):
-
-        @wraps(fun)
-        def __sleepdeprived(*args, **kwargs):
-            old_sleep = module.sleep
-            module.sleep = noop
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                module.sleep = old_sleep
-
-        return __sleepdeprived
-
-    return _sleepdeprived
-
-
-def skip_if_environ(env_var_name):
-
-    def _wrap_test(fun):
-
-        @wraps(fun)
-        def _skips_if_environ(*args, **kwargs):
-            if os.environ.get(env_var_name):
-                raise SkipTest('SKIP %s: %s set\n' % (
-                    fun.__name__, env_var_name))
-            return fun(*args, **kwargs)
-
-        return _skips_if_environ
-
-    return _wrap_test
-
-
-def _skip_test(reason, sign):
-
-    def _wrap_test(fun):
-
-        @wraps(fun)
-        def _skipped_test(*args, **kwargs):
-            raise SkipTest('%s: %s' % (sign, reason))
-
-        return _skipped_test
-    return _wrap_test
-
-
-def todo(reason):
-    """TODO test decorator."""
-    return _skip_test(reason, 'TODO')
-
-
-def skip(reason):
-    """Skip test decorator."""
-    return _skip_test(reason, 'SKIP')
-
-
-def skip_if(predicate, reason):
-    """Skip test if predicate is :const:`True`."""
-
-    def _inner(fun):
-        return predicate and skip(reason)(fun) or fun
-
-    return _inner
-
-
-def skip_unless(predicate, reason):
-    """Skip test if predicate is :const:`False`."""
-    return skip_if(not predicate, reason)
-
-
-# Taken from
-# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
-@contextmanager
-def mask_modules(*modnames):
-    """Ban some modules from being importable inside the context
-
-    For example:
-
-        >>> with mask_modules('sys'):
-        ...     try:
-        ...         import sys
-        ...     except ImportError:
-        ...         print('sys not found')
-        sys not found
-
-        >>> import sys  # noqa
-        >>> sys.version
-        (2, 5, 2, 'final', 0)
-
-    """
-
-    realimport = builtins.__import__
-
-    def myimp(name, *args, **kwargs):
-        if name in modnames:
-            raise ImportError('No module named %s' % name)
-        else:
-            return realimport(name, *args, **kwargs)
-
-    builtins.__import__ = myimp
-    try:
-        yield True
-    finally:
-        builtins.__import__ = realimport
-
-
-@contextmanager
-def override_stdouts():
-    """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
-    prev_out, prev_err = sys.stdout, sys.stderr
-    prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
-    mystdout, mystderr = WhateverIO(), WhateverIO()
-    sys.stdout = sys.__stdout__ = mystdout
-    sys.stderr = sys.__stderr__ = mystderr
-
-    try:
-        yield mystdout, mystderr
-    finally:
-        sys.stdout = prev_out
-        sys.stderr = prev_err
-        sys.__stdout__ = prev_rout
-        sys.__stderr__ = prev_rerr
-
-
-def disable_stdouts(fun):
-
-    @wraps(fun)
-    def disable(*args, **kwargs):
-        with override_stdouts():
-            return fun(*args, **kwargs)
-    return disable
-
-
-def _old_patch(module, name, mocked):
-    module = importlib.import_module(module)
-
-    def _patch(fun):
-
-        @wraps(fun)
-        def __patched(*args, **kwargs):
-            prev = getattr(module, name)
-            setattr(module, name, mocked)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                setattr(module, name, prev)
-        return __patched
-    return _patch
-
-
-@contextmanager
-def replace_module_value(module, name, value=None):
-    has_prev = hasattr(module, name)
-    prev = getattr(module, name, None)
-    if value:
-        setattr(module, name, value)
-    else:
-        try:
-            delattr(module, name)
-        except AttributeError:
-            pass
-    try:
-        yield
-    finally:
-        if prev is not None:
-            setattr(module, name, prev)
-        if not has_prev:
-            try:
-                delattr(module, name)
-            except AttributeError:
-                pass
-pypy_version = partial(
-    replace_module_value, sys, 'pypy_version_info',
-)
-platform_pyimp = partial(
-    replace_module_value, platform, 'python_implementation',
-)
-
-
-@contextmanager
-def sys_platform(value):
-    prev, sys.platform = sys.platform, value
-    try:
-        yield
-    finally:
-        sys.platform = prev
-
-
-@contextmanager
-def reset_modules(*modules):
-    prev = {k: sys.modules.pop(k) for k in modules if k in sys.modules}
-    try:
-        yield
-    finally:
-        sys.modules.update(prev)
-
-
-@contextmanager
-def patch_modules(*modules):
-    prev = {}
-    for mod in modules:
-        prev[mod] = sys.modules.get(mod)
-        sys.modules[mod] = ModuleType(module_name_t(mod))
-    try:
-        yield
-    finally:
-        for name, mod in items(prev):
-            if mod is None:
-                sys.modules.pop(name, None)
-            else:
-                sys.modules[name] = mod
-
-
-@contextmanager
-def mock_module(*names):
-    prev = {}
-
-    class MockModule(ModuleType):
-
-        def __getattr__(self, attr):
-            setattr(self, attr, Mock())
-            return ModuleType.__getattribute__(self, attr)
-
-    mods = []
-    for name in names:
-        try:
-            prev[name] = sys.modules[name]
-        except KeyError:
-            pass
-        mod = sys.modules[name] = MockModule(module_name_t(name))
-        mods.append(mod)
-    try:
-        yield mods
-    finally:
-        for name in names:
-            try:
-                sys.modules[name] = prev[name]
-            except KeyError:
-                try:
-                    del(sys.modules[name])
-                except KeyError:
-                    pass
-
-
-@contextmanager
-def mock_context(mock, typ=Mock):
-    context = mock.return_value = Mock()
-    context.__enter__ = typ()
-    context.__exit__ = typ()
-
-    def on_exit(*x):
-        if x[0]:
-            reraise(x[0], x[1], x[2])
-    context.__exit__.side_effect = on_exit
-    context.__enter__.return_value = context
-    try:
-        yield context
-    finally:
-        context.reset()
-
-
-@contextmanager
-def mock_open(typ=WhateverIO, side_effect=None):
-    with patch(open_fqdn) as open_:
-        with mock_context(open_) as context:
-            if side_effect is not None:
-                context.__enter__.side_effect = side_effect
-            val = context.__enter__.return_value = typ()
-            val.__exit__ = Mock()
-            yield val
-
-
+@decorator
 @contextmanager
 @contextmanager
 def assert_signal_called(signal, **expected):
 def assert_signal_called(signal, **expected):
     handler = Mock()
     handler = Mock()
@@ -889,26 +258,6 @@ def assert_signal_called(signal, **expected):
     handler.assert_called_with(signal=signal, **expected)
     handler.assert_called_with(signal=signal, **expected)
 
 
 
 
-def skip_if_pypy(fun):
-
-    @wraps(fun)
-    def _inner(*args, **kwargs):
-        if getattr(sys, 'pypy_version_info', None):
-            raise SkipTest('does not work on PyPy')
-        return fun(*args, **kwargs)
-    return _inner
-
-
-def skip_if_jython(fun):
-
-    @wraps(fun)
-    def _inner(*args, **kwargs):
-        if sys.platform.startswith('java'):
-            raise SkipTest('does not work on Jython')
-        return fun(*args, **kwargs)
-    return _inner
-
-
 def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
 def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
                 errbacks=None, chain=None, shadow=None, utc=None, **options):
                 errbacks=None, chain=None, shadow=None, utc=None, **options):
     from celery import uuid
     from celery import uuid
@@ -979,16 +328,18 @@ def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
     )
     )
 
 
 
 
-@contextmanager
-def restore_logging():
-    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
-    root = logging.getLogger()
-    level = root.level
-    handlers = root.handlers
+def _old_patch(module, name, mocked):
+    module = importlib.import_module(module)
 
 
-    try:
-        yield
-    finally:
-        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
-        root.level = level
-        root.handlers[:] = handlers
+    def _patch(fun):
+
+        @wraps(fun)
+        def __patched(*args, **kwargs):
+            prev = getattr(module, name)
+            setattr(module, name, mocked)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                setattr(module, name, prev)
+        return __patched
+    return _patch

+ 1 - 2
celery/tests/concurrency/test_eventlet.py

@@ -12,13 +12,12 @@ from celery.concurrency.eventlet import (
 from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
 from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
 
 
 
 
+@skip_if_pypy()
 class EventletCase(AppCase):
 class EventletCase(AppCase):
 
 
-    @skip_if_pypy
     def setup(self):
     def setup(self):
         self.mock_modules(*eventlet_modules)
         self.mock_modules(*eventlet_modules)
 
 
-    @skip_if_pypy
     def teardown(self):
     def teardown(self):
         for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
         for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
             try:
             try:

+ 1 - 1
celery/tests/concurrency/test_gevent.py

@@ -17,9 +17,9 @@ gevent_modules = (
 )
 )
 
 
 
 
+@skip_if_pypy()
 class GeventCase(AppCase):
 class GeventCase(AppCase):
 
 
-    @skip_if_pypy
     def setup(self):
     def setup(self):
         self.mock_modules(*gevent_modules)
         self.mock_modules(*gevent_modules)
 
 

+ 2 - 5
celery/tests/concurrency/test_pool.py

@@ -5,7 +5,7 @@ import itertools
 
 
 from billiard.einfo import ExceptionInfo
 from billiard.einfo import ExceptionInfo
 
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_unless_module
 
 
 
 
 def do_something(i):
 def do_something(i):
@@ -23,13 +23,10 @@ def raise_something(i):
         return ExceptionInfo()
         return ExceptionInfo()
 
 
 
 
+@skip_unless_module('multiprocessing')
 class test_TaskPool(AppCase):
 class test_TaskPool(AppCase):
 
 
     def setup(self):
     def setup(self):
-        try:
-            __import__('multiprocessing')
-        except ImportError:
-            raise SkipTest('multiprocessing not supported')
         from celery.concurrency.prefork import TaskPool
         from celery.concurrency.prefork import TaskPool
         self.TaskPool = TaskPool
         self.TaskPool = TaskPool
 
 

+ 7 - 16
celery/tests/concurrency/test_prefork.py

@@ -3,7 +3,6 @@ from __future__ import absolute_import, unicode_literals
 import errno
 import errno
 import os
 import os
 import socket
 import socket
-import sys
 
 
 from itertools import cycle
 from itertools import cycle
 
 
@@ -13,7 +12,9 @@ from celery.five import range
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
+from celery.tests.case import (
+    AppCase, Mock, patch, restore_logging, skip_if_win32, skip_unless_module,
+)
 
 
 try:
 try:
     from celery.concurrency import prefork as mp
     from celery.concurrency import prefork as mp
@@ -185,21 +186,14 @@ class ExeMockTaskPool(mp.TaskPool):
     Pool = BlockingPool = ExeMockPool
     Pool = BlockingPool = ExeMockPool
 
 
 
 
+@skip_unless_module('multiprocessing')
 class PoolCase(AppCase):
 class PoolCase(AppCase):
-
-    def setup(self):
-        try:
-            import multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('multiprocessing not supported')
+    pass
 
 
 
 
+@skip_if_win32
 class test_AsynPool(PoolCase):
 class test_AsynPool(PoolCase):
 
 
-    def setup(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
-
     def test_gen_not_started(self):
     def test_gen_not_started(self):
 
 
         def gen():
         def gen():
@@ -303,12 +297,9 @@ class test_AsynPool(PoolCase):
         w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
         w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
 
 
 
 
+@skip_if_win32
 class test_ResultHandler(PoolCase):
 class test_ResultHandler(PoolCase):
 
 
-    def setup(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
-
     def test_process_result(self):
     def test_process_result(self):
         x = asynpool.ResultHandler(
         x = asynpool.ResultHandler(
             Mock(), Mock(), {}, Mock(),
             Mock(), Mock(), {}, Mock(),

+ 4 - 3
celery/tests/contrib/test_rdb.py

@@ -8,7 +8,8 @@ from celery.contrib.rdb import (
     debugger,
     debugger,
     set_trace,
     set_trace,
 )
 )
-from celery.tests.case import AppCase, Mock, WhateverIO, patch, skip_if_pypy
+from celery.five import WhateverIO
+from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
 
 
 
 
 class SockErr(socket.error):
 class SockErr(socket.error):
@@ -31,7 +32,7 @@ class test_Rdb(AppCase):
         self.assertTrue(debugger.return_value.set_trace.called)
         self.assertTrue(debugger.return_value.set_trace.called)
 
 
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
-    @skip_if_pypy
+    @skip_if_pypy()
     def test_rdb(self, get_avail_port):
     def test_rdb(self, get_avail_port):
         sock = Mock()
         sock = Mock()
         get_avail_port.return_value = (sock, 8000)
         get_avail_port.return_value = (sock, 8000)
@@ -75,7 +76,7 @@ class test_Rdb(AppCase):
             rdb.set_quit.assert_called_with()
             rdb.set_quit.assert_called_with()
 
 
     @patch('socket.socket')
     @patch('socket.socket')
-    @skip_if_pypy
+    @skip_if_pypy()
     def test_get_avail_port(self, sock):
     def test_get_avail_port(self, sock):
         out = WhateverIO()
         out = WhateverIO()
         sock.return_value.accept.return_value = (Mock(), ['helu'])
         sock.return_value.accept.return_value = (Mock(), ['helu'])

+ 2 - 6
celery/tests/events/test_cursesmon.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_unless_module
 
 
 
 
 class MockWindow(object):
 class MockWindow(object):
@@ -9,14 +9,10 @@ class MockWindow(object):
         return self.y, self.x
         return self.y, self.x
 
 
 
 
+@skip_unless_module('curses')
 class test_CursesDisplay(AppCase):
 class test_CursesDisplay(AppCase):
 
 
     def setup(self):
     def setup(self):
-        try:
-            import curses  # noqa
-        except (ImportError, OSError):
-            raise SkipTest('curses monitor requires curses')
-
         from celery.events import cursesmon
         from celery.events import cursesmon
         self.monitor = cursesmon.CursesMonitor(object(), app=self.app)
         self.monitor = cursesmon.CursesMonitor(object(), app=self.app)
         self.win = MockWindow()
         self.win = MockWindow()

+ 2 - 2
celery/tests/events/test_state.py

@@ -19,7 +19,7 @@ from celery.events.state import (
 )
 )
 from celery.five import range
 from celery.five import range
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.tests.case import AppCase, Mock, SkipTest, patch
+from celery.tests.case import AppCase, Mock, patch, todo
 
 
 try:
 try:
     Decimal(2.6)
     Decimal(2.6)
@@ -374,8 +374,8 @@ class test_State(AppCase):
         self.assertEqual(now[1][0], tC)
         self.assertEqual(now[1][0], tC)
         self.assertEqual(now[2][0], tB)
         self.assertEqual(now[2][0], tB)
 
 
+    @todo(reason='not working')
     def test_task_descending_clock_ordering(self):
     def test_task_descending_clock_ordering(self):
-        raise SkipTest('not working')
         state = State()
         state = State()
         r = ev_logical_clock_ordering(state)
         r = ev_logical_clock_ordering(state)
         tA, tB, tC = r.uids
         tA, tB, tC = r.uids

+ 3 - 7
celery/tests/security/case.py

@@ -1,12 +1,8 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_unless_module
 
 
 
 
+@skip_unless_module('OpenSSL.crypto', name='pyOpenSSL')
 class SecurityCase(AppCase):
 class SecurityCase(AppCase):
-
-    def setup(self):
-        try:
-            from OpenSSL import crypto  # noqa
-        except ImportError:
-            raise SkipTest('OpenSSL.crypto not installed')
+    pass

+ 2 - 2
celery/tests/security/test_certificate.py

@@ -6,7 +6,7 @@ from celery.security.certificate import Certificate, CertStore, FSCertStore
 from . import CERT1, CERT2, KEY1
 from . import CERT1, CERT2, KEY1
 from .case import SecurityCase
 from .case import SecurityCase
 
 
-from celery.tests.case import Mock, SkipTest, mock_open, patch
+from celery.tests.case import Mock, mock_open, patch, todo
 
 
 
 
 class test_Certificate(SecurityCase):
 class test_Certificate(SecurityCase):
@@ -27,8 +27,8 @@ class test_Certificate(SecurityCase):
         with self.assertRaises(SecurityError):
         with self.assertRaises(SecurityError):
             Certificate(KEY1)
             Certificate(KEY1)
 
 
+    @todo(reason='cert expired')
     def test_has_expired(self):
     def test_has_expired(self):
-        raise SkipTest('cert expired')
         self.assertFalse(Certificate(CERT1).has_expired())
         self.assertFalse(Certificate(CERT1).has_expired())
 
 
     def test_has_expired_mock(self):
     def test_has_expired_mock(self):

+ 3 - 11
celery/tests/utils/test_datastructures.py

@@ -1,7 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import pickle
 import pickle
-import sys
 
 
 from collections import Mapping
 from collections import Mapping
 from itertools import count
 from itertools import count
@@ -16,10 +15,10 @@ from celery.datastructures import (
     ConfigurationView,
     ConfigurationView,
     DependencyGraph,
     DependencyGraph,
 )
 )
-from celery.five import items
+from celery.five import WhateverIO, items
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import Case, Mock, WhateverIO, SkipTest
+from celery.tests.case import Case, Mock, skip_if_win32
 
 
 
 
 class test_DictAttribute(Case):
 class test_DictAttribute(Case):
@@ -168,15 +167,10 @@ class test_ExceptionInfo(Case):
             self.assertTrue(r)
             self.assertTrue(r)
 
 
 
 
+@skip_if_win32()
 class test_LimitedSet(Case):
 class test_LimitedSet(Case):
 
 
-    def setUp(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Not working on Windows')
-
     def test_add(self):
     def test_add(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Not working properly on Windows')
         s = LimitedSet(maxlen=2)
         s = LimitedSet(maxlen=2)
         s.add('foo')
         s.add('foo')
         s.add('bar')
         s.add('bar')
@@ -239,8 +233,6 @@ class test_LimitedSet(Case):
         self.assertEqual(pickle.loads(pickle.dumps(s)), s)
         self.assertEqual(pickle.loads(pickle.dumps(s)), s)
 
 
     def test_iter(self):
     def test_iter(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Not working on Windows')
         s = LimitedSet(maxlen=3)
         s = LimitedSet(maxlen=3)
         items = ['foo', 'bar', 'baz', 'xaz']
         items = ['foo', 'bar', 'baz', 'xaz']
         for item in items:
         for item in items:

+ 571 - 561
celery/tests/utils/test_platforms.py

@@ -8,7 +8,7 @@ import tempfile
 
 
 from celery import _find_option_with_arg
 from celery import _find_option_with_arg
 from celery import platforms
 from celery import platforms
-from celery.five import open_fqdn
+from celery.five import WhateverIO
 from celery.platforms import (
 from celery.platforms import (
     get_fdmax,
     get_fdmax,
     ignore_errno,
     ignore_errno,
@@ -38,9 +38,10 @@ try:
 except ImportError:  # pragma: no cover
 except ImportError:  # pragma: no cover
     resource = None  # noqa
     resource = None  # noqa
 
 
+from celery.tests._case import open_fqdn
 from celery.tests.case import (
 from celery.tests.case import (
-    Case, WhateverIO, Mock, SkipTest,
-    call, override_stdouts, mock_open, patch,
+    Case, Mock,
+    call, override_stdouts, mock_open, patch, skip_if_win32,
 )
 )
 
 
 
 
@@ -59,12 +60,9 @@ class test_find_option_with_arg(Case):
         )
         )
 
 
 
 
+@skip_if_win32()
 class test_fd_by_path(Case):
 class test_fd_by_path(Case):
 
 
-    def setUp(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
-
     def test_finds(self):
     def test_finds(self):
         test_file = tempfile.NamedTemporaryFile()
         test_file = tempfile.NamedTemporaryFile()
         try:
         try:
@@ -143,9 +141,8 @@ class test_Signals(Case):
         self.assertTrue(signals.supported('INT'))
         self.assertTrue(signals.supported('INT'))
         self.assertFalse(signals.supported('SIGIMAGINARY'))
         self.assertFalse(signals.supported('SIGIMAGINARY'))
 
 
+    @skip_if_win32()
     def test_reset_alarm(self):
     def test_reset_alarm(self):
-        if sys.platform == 'win32':
-            raise SkipTest('signal.alarm not available on Windows')
         with patch('signal.alarm') as _alarm:
         with patch('signal.alarm') as _alarm:
             signals.reset_alarm()
             signals.reset_alarm()
             _alarm.assert_called_with(0)
             _alarm.assert_called_with(0)
@@ -189,622 +186,635 @@ class test_Signals(Case):
         signals['INT'] = lambda *a: a
         signals['INT'] = lambda *a: a
 
 
 
 
-if not platforms.IS_WINDOWS:
-
-    class test_get_fdmax(Case):
-
-        @patch('resource.getrlimit')
-        def test_when_infinity(self, getrlimit):
-            with patch('os.sysconf') as sysconfig:
-                sysconfig.side_effect = KeyError()
-                getrlimit.return_value = [None, resource.RLIM_INFINITY]
-                default = object()
-                self.assertIs(get_fdmax(default), default)
-
-        @patch('resource.getrlimit')
-        def test_when_actual(self, getrlimit):
-            with patch('os.sysconf') as sysconfig:
-                sysconfig.side_effect = KeyError()
-                getrlimit.return_value = [None, 13]
-                self.assertEqual(get_fdmax(None), 13)
-
-    class test_maybe_drop_privileges(Case):
-
-        def test_on_windows(self):
-            prev, sys.platform = sys.platform, 'win32'
-            try:
-                maybe_drop_privileges()
-            finally:
-                sys.platform = prev
-
-        @patch('os.getegid')
-        @patch('os.getgid')
-        @patch('os.geteuid')
-        @patch('os.getuid')
-        @patch('celery.platforms.parse_uid')
-        @patch('celery.platforms.parse_gid')
-        @patch('pwd.getpwuid')
-        @patch('celery.platforms.setgid')
-        @patch('celery.platforms.setuid')
-        @patch('celery.platforms.initgroups')
-        def test_with_uid(self, initgroups, setuid, setgid,
-                          getpwuid, parse_gid, parse_uid, getuid, geteuid,
-                          getgid, getegid):
-            geteuid.return_value = 10
-            getuid.return_value = 10
-
-            class pw_struct(object):
-                pw_gid = 50001
-
-            def raise_on_second_call(*args, **kwargs):
-                setuid.side_effect = OSError()
-                setuid.side_effect.errno = errno.EPERM
-            setuid.side_effect = raise_on_second_call
-            getpwuid.return_value = pw_struct()
-            parse_uid.return_value = 5001
-            parse_gid.return_value = 5001
-            maybe_drop_privileges(uid='user')
-            parse_uid.assert_called_with('user')
-            getpwuid.assert_called_with(5001)
-            setgid.assert_called_with(50001)
-            initgroups.assert_called_with(5001, 50001)
-            setuid.assert_has_calls([call(5001), call(0)])
-
-            setuid.side_effect = raise_on_second_call
-
-            def to_root_on_second_call(mock, first):
-                return_value = [first]
-
-                def on_first_call(*args, **kwargs):
-                    ret, return_value[0] = return_value[0], 0
-                    return ret
-                mock.side_effect = on_first_call
-            to_root_on_second_call(geteuid, 10)
-            to_root_on_second_call(getuid, 10)
-            with self.assertRaises(AssertionError):
-                maybe_drop_privileges(uid='user')
+@skip_if_win32()
+class test_get_fdmax(Case):
 
 
-            getuid.return_value = getuid.side_effect = None
-            geteuid.return_value = geteuid.side_effect = None
-            getegid.return_value = 0
-            getgid.return_value = 0
-            setuid.side_effect = raise_on_second_call
-            with self.assertRaises(AssertionError):
-                maybe_drop_privileges(gid='group')
-
-            getuid.reset_mock()
-            geteuid.reset_mock()
-            setuid.reset_mock()
-            getuid.side_effect = geteuid.side_effect = None
-
-            def raise_on_second_call(*args, **kwargs):
-                setuid.side_effect = OSError()
-                setuid.side_effect.errno = errno.ENOENT
-            setuid.side_effect = raise_on_second_call
-            with self.assertRaises(OSError):
-                maybe_drop_privileges(uid='user')
-
-        @patch('celery.platforms.parse_uid')
-        @patch('celery.platforms.parse_gid')
-        @patch('celery.platforms.setgid')
-        @patch('celery.platforms.setuid')
-        @patch('celery.platforms.initgroups')
-        def test_with_guid(self, initgroups, setuid, setgid,
-                           parse_gid, parse_uid):
-
-            def raise_on_second_call(*args, **kwargs):
-                setuid.side_effect = OSError()
-                setuid.side_effect.errno = errno.EPERM
-            setuid.side_effect = raise_on_second_call
-            parse_uid.return_value = 5001
-            parse_gid.return_value = 50001
-            maybe_drop_privileges(uid='user', gid='group')
-            parse_uid.assert_called_with('user')
-            parse_gid.assert_called_with('group')
-            setgid.assert_called_with(50001)
-            initgroups.assert_called_with(5001, 50001)
-            setuid.assert_has_calls([call(5001), call(0)])
+    @patch('resource.getrlimit')
+    def test_when_infinity(self, getrlimit):
+        with patch('os.sysconf') as sysconfig:
+            sysconfig.side_effect = KeyError()
+            getrlimit.return_value = [None, resource.RLIM_INFINITY]
+            default = object()
+            self.assertIs(get_fdmax(default), default)
 
 
-            setuid.side_effect = None
-            with self.assertRaises(RuntimeError):
-                maybe_drop_privileges(uid='user', gid='group')
+    @patch('resource.getrlimit')
+    def test_when_actual(self, getrlimit):
+        with patch('os.sysconf') as sysconfig:
+            sysconfig.side_effect = KeyError()
+            getrlimit.return_value = [None, 13]
+            self.assertEqual(get_fdmax(None), 13)
+
+
+@skip_if_win32()
+class test_maybe_drop_privileges(Case):
+
+    def test_on_windows(self):
+        prev, sys.platform = sys.platform, 'win32'
+        try:
+            maybe_drop_privileges()
+        finally:
+            sys.platform = prev
+
+    @patch('os.getegid')
+    @patch('os.getgid')
+    @patch('os.geteuid')
+    @patch('os.getuid')
+    @patch('celery.platforms.parse_uid')
+    @patch('celery.platforms.parse_gid')
+    @patch('pwd.getpwuid')
+    @patch('celery.platforms.setgid')
+    @patch('celery.platforms.setuid')
+    @patch('celery.platforms.initgroups')
+    def test_with_uid(self, initgroups, setuid, setgid,
+                      getpwuid, parse_gid, parse_uid, getuid, geteuid,
+                      getgid, getegid):
+        geteuid.return_value = 10
+        getuid.return_value = 10
+
+        class pw_struct(object):
+            pw_gid = 50001
+
+        def raise_on_second_call(*args, **kwargs):
             setuid.side_effect = OSError()
             setuid.side_effect = OSError()
-            setuid.side_effect.errno = errno.EINVAL
-            with self.assertRaises(OSError):
-                maybe_drop_privileges(uid='user', gid='group')
-
-        @patch('celery.platforms.setuid')
-        @patch('celery.platforms.setgid')
-        @patch('celery.platforms.parse_gid')
-        def test_only_gid(self, parse_gid, setgid, setuid):
-            parse_gid.return_value = 50001
+            setuid.side_effect.errno = errno.EPERM
+        setuid.side_effect = raise_on_second_call
+        getpwuid.return_value = pw_struct()
+        parse_uid.return_value = 5001
+        parse_gid.return_value = 5001
+        maybe_drop_privileges(uid='user')
+        parse_uid.assert_called_with('user')
+        getpwuid.assert_called_with(5001)
+        setgid.assert_called_with(50001)
+        initgroups.assert_called_with(5001, 50001)
+        setuid.assert_has_calls([call(5001), call(0)])
+
+        setuid.side_effect = raise_on_second_call
+
+        def to_root_on_second_call(mock, first):
+            return_value = [first]
+
+            def on_first_call(*args, **kwargs):
+                ret, return_value[0] = return_value[0], 0
+                return ret
+            mock.side_effect = on_first_call
+        to_root_on_second_call(geteuid, 10)
+        to_root_on_second_call(getuid, 10)
+        with self.assertRaises(AssertionError):
+            maybe_drop_privileges(uid='user')
+
+        getuid.return_value = getuid.side_effect = None
+        geteuid.return_value = geteuid.side_effect = None
+        getegid.return_value = 0
+        getgid.return_value = 0
+        setuid.side_effect = raise_on_second_call
+        with self.assertRaises(AssertionError):
             maybe_drop_privileges(gid='group')
             maybe_drop_privileges(gid='group')
-            parse_gid.assert_called_with('group')
-            setgid.assert_called_with(50001)
-            self.assertFalse(setuid.called)
 
 
-    class test_setget_uid_gid(Case):
+        getuid.reset_mock()
+        geteuid.reset_mock()
+        setuid.reset_mock()
+        getuid.side_effect = geteuid.side_effect = None
 
 
-        @patch('celery.platforms.parse_uid')
-        @patch('os.setuid')
-        def test_setuid(self, _setuid, parse_uid):
-            parse_uid.return_value = 5001
-            setuid('user')
-            parse_uid.assert_called_with('user')
-            _setuid.assert_called_with(5001)
+        def raise_on_second_call(*args, **kwargs):
+            setuid.side_effect = OSError()
+            setuid.side_effect.errno = errno.ENOENT
+        setuid.side_effect = raise_on_second_call
+        with self.assertRaises(OSError):
+            maybe_drop_privileges(uid='user')
 
 
-        @patch('celery.platforms.parse_gid')
-        @patch('os.setgid')
-        def test_setgid(self, _setgid, parse_gid):
-            parse_gid.return_value = 50001
-            setgid('group')
-            parse_gid.assert_called_with('group')
-            _setgid.assert_called_with(50001)
+    @patch('celery.platforms.parse_uid')
+    @patch('celery.platforms.parse_gid')
+    @patch('celery.platforms.setgid')
+    @patch('celery.platforms.setuid')
+    @patch('celery.platforms.initgroups')
+    def test_with_guid(self, initgroups, setuid, setgid,
+                       parse_gid, parse_uid):
 
 
-        def test_parse_uid_when_int(self):
-            self.assertEqual(parse_uid(5001), 5001)
+        def raise_on_second_call(*args, **kwargs):
+            setuid.side_effect = OSError()
+            setuid.side_effect.errno = errno.EPERM
+        setuid.side_effect = raise_on_second_call
+        parse_uid.return_value = 5001
+        parse_gid.return_value = 50001
+        maybe_drop_privileges(uid='user', gid='group')
+        parse_uid.assert_called_with('user')
+        parse_gid.assert_called_with('group')
+        setgid.assert_called_with(50001)
+        initgroups.assert_called_with(5001, 50001)
+        setuid.assert_has_calls([call(5001), call(0)])
+
+        setuid.side_effect = None
+        with self.assertRaises(RuntimeError):
+            maybe_drop_privileges(uid='user', gid='group')
+        setuid.side_effect = OSError()
+        setuid.side_effect.errno = errno.EINVAL
+        with self.assertRaises(OSError):
+            maybe_drop_privileges(uid='user', gid='group')
 
 
-        @patch('pwd.getpwnam')
-        def test_parse_uid_when_existing_name(self, getpwnam):
+    @patch('celery.platforms.setuid')
+    @patch('celery.platforms.setgid')
+    @patch('celery.platforms.parse_gid')
+    def test_only_gid(self, parse_gid, setgid, setuid):
+        parse_gid.return_value = 50001
+        maybe_drop_privileges(gid='group')
+        parse_gid.assert_called_with('group')
+        setgid.assert_called_with(50001)
+        self.assertFalse(setuid.called)
 
 
-            class pwent(object):
-                pw_uid = 5001
 
 
-            getpwnam.return_value = pwent()
-            self.assertEqual(parse_uid('user'), 5001)
+@skip_if_win32()
+class test_setget_uid_gid(Case):
 
 
-        @patch('pwd.getpwnam')
-        def test_parse_uid_when_nonexisting_name(self, getpwnam):
-            getpwnam.side_effect = KeyError('user')
+    @patch('celery.platforms.parse_uid')
+    @patch('os.setuid')
+    def test_setuid(self, _setuid, parse_uid):
+        parse_uid.return_value = 5001
+        setuid('user')
+        parse_uid.assert_called_with('user')
+        _setuid.assert_called_with(5001)
 
 
-            with self.assertRaises(KeyError):
-                parse_uid('user')
+    @patch('celery.platforms.parse_gid')
+    @patch('os.setgid')
+    def test_setgid(self, _setgid, parse_gid):
+        parse_gid.return_value = 50001
+        setgid('group')
+        parse_gid.assert_called_with('group')
+        _setgid.assert_called_with(50001)
 
 
-        def test_parse_gid_when_int(self):
-            self.assertEqual(parse_gid(50001), 50001)
+    def test_parse_uid_when_int(self):
+        self.assertEqual(parse_uid(5001), 5001)
 
 
-        @patch('grp.getgrnam')
-        def test_parse_gid_when_existing_name(self, getgrnam):
+    @patch('pwd.getpwnam')
+    def test_parse_uid_when_existing_name(self, getpwnam):
 
 
-            class grent(object):
-                gr_gid = 50001
+        class pwent(object):
+            pw_uid = 5001
+
+        getpwnam.return_value = pwent()
+        self.assertEqual(parse_uid('user'), 5001)
 
 
-            getgrnam.return_value = grent()
-            self.assertEqual(parse_gid('group'), 50001)
+    @patch('pwd.getpwnam')
+    def test_parse_uid_when_nonexisting_name(self, getpwnam):
+        getpwnam.side_effect = KeyError('user')
 
 
-        @patch('grp.getgrnam')
-        def test_parse_gid_when_nonexisting_name(self, getgrnam):
-            getgrnam.side_effect = KeyError('group')
+        with self.assertRaises(KeyError):
+            parse_uid('user')
 
 
-            with self.assertRaises(KeyError):
-                parse_gid('group')
+    def test_parse_gid_when_int(self):
+        self.assertEqual(parse_gid(50001), 50001)
 
 
-    class test_initgroups(Case):
+    @patch('grp.getgrnam')
+    def test_parse_gid_when_existing_name(self, getgrnam):
 
 
-        @patch('pwd.getpwuid')
-        @patch('os.initgroups', create=True)
-        def test_with_initgroups(self, initgroups_, getpwuid):
+        class grent(object):
+            gr_gid = 50001
+
+        getgrnam.return_value = grent()
+        self.assertEqual(parse_gid('group'), 50001)
+
+    @patch('grp.getgrnam')
+    def test_parse_gid_when_nonexisting_name(self, getgrnam):
+        getgrnam.side_effect = KeyError('group')
+
+        with self.assertRaises(KeyError):
+            parse_gid('group')
+
+
+@skip_if_win32()
+class test_initgroups(Case):
+
+    @patch('pwd.getpwuid')
+    @patch('os.initgroups', create=True)
+    def test_with_initgroups(self, initgroups_, getpwuid):
+        getpwuid.return_value = ['user']
+        initgroups(5001, 50001)
+        initgroups_.assert_called_with('user', 50001)
+
+    @patch('celery.platforms.setgroups')
+    @patch('grp.getgrall')
+    @patch('pwd.getpwuid')
+    def test_without_initgroups(self, getpwuid, getgrall, setgroups):
+        prev = getattr(os, 'initgroups', None)
+        try:
+            delattr(os, 'initgroups')
+        except AttributeError:
+            pass
+        try:
             getpwuid.return_value = ['user']
             getpwuid.return_value = ['user']
+
+            class grent(object):
+                gr_mem = ['user']
+
+                def __init__(self, gid):
+                    self.gr_gid = gid
+
+            getgrall.return_value = [grent(1), grent(2), grent(3)]
             initgroups(5001, 50001)
             initgroups(5001, 50001)
-            initgroups_.assert_called_with('user', 50001)
-
-        @patch('celery.platforms.setgroups')
-        @patch('grp.getgrall')
-        @patch('pwd.getpwuid')
-        def test_without_initgroups(self, getpwuid, getgrall, setgroups):
-            prev = getattr(os, 'initgroups', None)
-            try:
-                delattr(os, 'initgroups')
-            except AttributeError:
-                pass
-            try:
-                getpwuid.return_value = ['user']
-
-                class grent(object):
-                    gr_mem = ['user']
-
-                    def __init__(self, gid):
-                        self.gr_gid = gid
-
-                getgrall.return_value = [grent(1), grent(2), grent(3)]
-                initgroups(5001, 50001)
-                setgroups.assert_called_with([1, 2, 3])
-            finally:
-                if prev:
-                    os.initgroups = prev
-
-    class test_detached(Case):
-
-        def test_without_resource(self):
-            prev, platforms.resource = platforms.resource, None
-            try:
-                with self.assertRaises(RuntimeError):
-                    detached()
-            finally:
-                platforms.resource = prev
-
-        @patch('celery.platforms._create_pidlock')
-        @patch('celery.platforms.signals')
-        @patch('celery.platforms.maybe_drop_privileges')
-        @patch('os.geteuid')
-        @patch(open_fqdn)
-        def test_default(self, open, geteuid, maybe_drop,
-                         signals, pidlock):
-            geteuid.return_value = 0
-            context = detached(uid='user', gid='group')
-            self.assertIsInstance(context, DaemonContext)
-            signals.reset.assert_called_with('SIGCLD')
-            maybe_drop.assert_called_with(uid='user', gid='group')
-            open.return_value = Mock()
-
-            geteuid.return_value = 5001
-            context = detached(uid='user', gid='group', logfile='/foo/bar')
-            self.assertIsInstance(context, DaemonContext)
-            self.assertTrue(context.after_chdir)
-            context.after_chdir()
-            open.assert_called_with('/foo/bar', 'a')
-            open.return_value.close.assert_called_with()
-
-            context = detached(pidfile='/foo/bar/pid')
-            self.assertIsInstance(context, DaemonContext)
-            self.assertTrue(context.after_chdir)
-            context.after_chdir()
-            pidlock.assert_called_with('/foo/bar/pid')
-
-    class test_DaemonContext(Case):
-
-        @patch('os.fork')
-        @patch('os.setsid')
-        @patch('os._exit')
-        @patch('os.chdir')
-        @patch('os.umask')
-        @patch('os.close')
-        @patch('os.closerange')
-        @patch('os.open')
-        @patch('os.dup2')
-        def test_open(self, dup2, open, close, closer, umask, chdir,
-                      _exit, setsid, fork):
-            x = DaemonContext(workdir='/opt/workdir', umask=0o22)
-            x.stdfds = [0, 1, 2]
-
-            fork.return_value = 0
-            with x:
-                self.assertTrue(x._is_open)
-                with x:
-                    pass
-            self.assertEqual(fork.call_count, 2)
-            setsid.assert_called_with()
-            self.assertFalse(_exit.called)
-
-            chdir.assert_called_with(x.workdir)
-            umask.assert_called_with(0o22)
-            self.assertTrue(dup2.called)
-
-            fork.reset_mock()
-            fork.return_value = 1
-            x = DaemonContext(workdir='/opt/workdir')
-            x.stdfds = [0, 1, 2]
-            with x:
-                pass
-            self.assertEqual(fork.call_count, 1)
-            _exit.assert_called_with(0)
+            setgroups.assert_called_with([1, 2, 3])
+        finally:
+            if prev:
+                os.initgroups = prev
 
 
-            x = DaemonContext(workdir='/opt/workdir', fake=True)
-            x.stdfds = [0, 1, 2]
-            x._detach = Mock()
-            with x:
-                pass
-            self.assertFalse(x._detach.called)
 
 
-            x.after_chdir = Mock()
+@skip_if_win32()
+class test_detached(Case):
+
+    def test_without_resource(self):
+        prev, platforms.resource = platforms.resource, None
+        try:
+            with self.assertRaises(RuntimeError):
+                detached()
+        finally:
+            platforms.resource = prev
+
+    @patch('celery.platforms._create_pidlock')
+    @patch('celery.platforms.signals')
+    @patch('celery.platforms.maybe_drop_privileges')
+    @patch('os.geteuid')
+    @patch(open_fqdn)
+    def test_default(self, open, geteuid, maybe_drop,
+                     signals, pidlock):
+        geteuid.return_value = 0
+        context = detached(uid='user', gid='group')
+        self.assertIsInstance(context, DaemonContext)
+        signals.reset.assert_called_with('SIGCLD')
+        maybe_drop.assert_called_with(uid='user', gid='group')
+        open.return_value = Mock()
+
+        geteuid.return_value = 5001
+        context = detached(uid='user', gid='group', logfile='/foo/bar')
+        self.assertIsInstance(context, DaemonContext)
+        self.assertTrue(context.after_chdir)
+        context.after_chdir()
+        open.assert_called_with('/foo/bar', 'a')
+        open.return_value.close.assert_called_with()
+
+        context = detached(pidfile='/foo/bar/pid')
+        self.assertIsInstance(context, DaemonContext)
+        self.assertTrue(context.after_chdir)
+        context.after_chdir()
+        pidlock.assert_called_with('/foo/bar/pid')
+
+
+@skip_if_win32()
+class test_DaemonContext(Case):
+
+    @patch('os.fork')
+    @patch('os.setsid')
+    @patch('os._exit')
+    @patch('os.chdir')
+    @patch('os.umask')
+    @patch('os.close')
+    @patch('os.closerange')
+    @patch('os.open')
+    @patch('os.dup2')
+    def test_open(self, dup2, open, close, closer, umask, chdir,
+                  _exit, setsid, fork):
+        x = DaemonContext(workdir='/opt/workdir', umask=0o22)
+        x.stdfds = [0, 1, 2]
+
+        fork.return_value = 0
+        with x:
+            self.assertTrue(x._is_open)
             with x:
             with x:
                 pass
                 pass
-            x.after_chdir.assert_called_with()
-
-            x = DaemonContext(workdir='/opt/workdir', umask='0755')
-            self.assertEqual(x.umask, 493)
-            x = DaemonContext(workdir='/opt/workdir', umask='493')
-            self.assertEqual(x.umask, 493)
-
-            x.redirect_to_null(None)
-
-            with patch('celery.platforms.mputil') as mputil:
-                x = DaemonContext(after_forkers=True)
-                x.open()
-                mputil._run_after_forkers.assert_called_with()
-                x = DaemonContext(after_forkers=False)
-                x.open()
-
-    class test_Pidfile(Case):
-
-        @patch('celery.platforms.Pidfile')
-        def test_create_pidlock(self, Pidfile):
-            p = Pidfile.return_value = Mock()
-            p.is_locked.return_value = True
-            p.remove_if_stale.return_value = False
-            with override_stdouts() as (_, err):
-                with self.assertRaises(SystemExit):
-                    create_pidlock('/var/pid')
-                self.assertIn('already exists', err.getvalue())
-
-            p.remove_if_stale.return_value = True
-            ret = create_pidlock('/var/pid')
-            self.assertIs(ret, p)
-
-        def test_context(self):
-            p = Pidfile('/var/pid')
-            p.write_pid = Mock()
-            p.remove = Mock()
-
-            with p as _p:
-                self.assertIs(_p, p)
-            p.write_pid.assert_called_with()
-            p.remove.assert_called_with()
+        self.assertEqual(fork.call_count, 2)
+        setsid.assert_called_with()
+        self.assertFalse(_exit.called)
+
+        chdir.assert_called_with(x.workdir)
+        umask.assert_called_with(0o22)
+        self.assertTrue(dup2.called)
+
+        fork.reset_mock()
+        fork.return_value = 1
+        x = DaemonContext(workdir='/opt/workdir')
+        x.stdfds = [0, 1, 2]
+        with x:
+            pass
+        self.assertEqual(fork.call_count, 1)
+        _exit.assert_called_with(0)
+
+        x = DaemonContext(workdir='/opt/workdir', fake=True)
+        x.stdfds = [0, 1, 2]
+        x._detach = Mock()
+        with x:
+            pass
+        self.assertFalse(x._detach.called)
+
+        x.after_chdir = Mock()
+        with x:
+            pass
+        x.after_chdir.assert_called_with()
+
+        x = DaemonContext(workdir='/opt/workdir', umask='0755')
+        self.assertEqual(x.umask, 493)
+        x = DaemonContext(workdir='/opt/workdir', umask='493')
+        self.assertEqual(x.umask, 493)
+
+        x.redirect_to_null(None)
+
+        with patch('celery.platforms.mputil') as mputil:
+            x = DaemonContext(after_forkers=True)
+            x.open()
+            mputil._run_after_forkers.assert_called_with()
+            x = DaemonContext(after_forkers=False)
+            x.open()
+
+
+@skip_if_win32()
+class test_Pidfile(Case):
+
+    @patch('celery.platforms.Pidfile')
+    def test_create_pidlock(self, Pidfile):
+        p = Pidfile.return_value = Mock()
+        p.is_locked.return_value = True
+        p.remove_if_stale.return_value = False
+        with override_stdouts() as (_, err):
+            with self.assertRaises(SystemExit):
+                create_pidlock('/var/pid')
+            self.assertIn('already exists', err.getvalue())
+
+        p.remove_if_stale.return_value = True
+        ret = create_pidlock('/var/pid')
+        self.assertIs(ret, p)
+
+    def test_context(self):
+        p = Pidfile('/var/pid')
+        p.write_pid = Mock()
+        p.remove = Mock()
+
+        with p as _p:
+            self.assertIs(_p, p)
+        p.write_pid.assert_called_with()
+        p.remove.assert_called_with()
+
+    def test_acquire_raises_LockFailed(self):
+        p = Pidfile('/var/pid')
+        p.write_pid = Mock()
+        p.write_pid.side_effect = OSError()
+
+        with self.assertRaises(LockFailed):
+            with p:
+                pass
 
 
-        def test_acquire_raises_LockFailed(self):
+    @patch('os.path.exists')
+    def test_is_locked(self, exists):
+        p = Pidfile('/var/pid')
+        exists.return_value = True
+        self.assertTrue(p.is_locked())
+        exists.return_value = False
+        self.assertFalse(p.is_locked())
+
+    def test_read_pid(self):
+        with mock_open() as s:
+            s.write('1816\n')
+            s.seek(0)
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
-            p.write_pid = Mock()
-            p.write_pid.side_effect = OSError()
-
-            with self.assertRaises(LockFailed):
-                with p:
-                    pass
+            self.assertEqual(p.read_pid(), 1816)
 
 
-        @patch('os.path.exists')
-        def test_is_locked(self, exists):
-            p = Pidfile('/var/pid')
-            exists.return_value = True
-            self.assertTrue(p.is_locked())
-            exists.return_value = False
-            self.assertFalse(p.is_locked())
-
-        def test_read_pid(self):
-            with mock_open() as s:
-                s.write('1816\n')
-                s.seek(0)
-                p = Pidfile('/var/pid')
-                self.assertEqual(p.read_pid(), 1816)
-
-        def test_read_pid_partially_written(self):
-            with mock_open() as s:
-                s.write('1816')
-                s.seek(0)
-                p = Pidfile('/var/pid')
-                with self.assertRaises(ValueError):
-                    p.read_pid()
-
-        def test_read_pid_raises_ENOENT(self):
-            exc = IOError()
-            exc.errno = errno.ENOENT
-            with mock_open(side_effect=exc):
-                p = Pidfile('/var/pid')
-                self.assertIsNone(p.read_pid())
-
-        def test_read_pid_raises_IOError(self):
-            exc = IOError()
-            exc.errno = errno.EAGAIN
-            with mock_open(side_effect=exc):
-                p = Pidfile('/var/pid')
-                with self.assertRaises(IOError):
-                    p.read_pid()
-
-        def test_read_pid_bogus_pidfile(self):
-            with mock_open() as s:
-                s.write('eighteensixteen\n')
-                s.seek(0)
-                p = Pidfile('/var/pid')
-                with self.assertRaises(ValueError):
-                    p.read_pid()
-
-        @patch('os.unlink')
-        def test_remove(self, unlink):
-            unlink.return_value = True
+    def test_read_pid_partially_written(self):
+        with mock_open() as s:
+            s.write('1816')
+            s.seek(0)
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
-            p.remove()
-            unlink.assert_called_with(p.path)
+            with self.assertRaises(ValueError):
+                p.read_pid()
 
 
-        @patch('os.unlink')
-        def test_remove_ENOENT(self, unlink):
-            exc = OSError()
-            exc.errno = errno.ENOENT
-            unlink.side_effect = exc
+    def test_read_pid_raises_ENOENT(self):
+        exc = IOError()
+        exc.errno = errno.ENOENT
+        with mock_open(side_effect=exc):
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
-            p.remove()
-            unlink.assert_called_with(p.path)
+            self.assertIsNone(p.read_pid())
 
 
-        @patch('os.unlink')
-        def test_remove_EACCES(self, unlink):
-            exc = OSError()
-            exc.errno = errno.EACCES
-            unlink.side_effect = exc
+    def test_read_pid_raises_IOError(self):
+        exc = IOError()
+        exc.errno = errno.EAGAIN
+        with mock_open(side_effect=exc):
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
-            p.remove()
-            unlink.assert_called_with(p.path)
+            with self.assertRaises(IOError):
+                p.read_pid()
 
 
-        @patch('os.unlink')
-        def test_remove_OSError(self, unlink):
-            exc = OSError()
-            exc.errno = errno.EAGAIN
-            unlink.side_effect = exc
+    def test_read_pid_bogus_pidfile(self):
+        with mock_open() as s:
+            s.write('eighteensixteen\n')
+            s.seek(0)
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
-            with self.assertRaises(OSError):
-                p.remove()
-            unlink.assert_called_with(p.path)
-
-        @patch('os.kill')
-        def test_remove_if_stale_process_alive(self, kill):
+            with self.assertRaises(ValueError):
+                p.read_pid()
+
+    @patch('os.unlink')
+    def test_remove(self, unlink):
+        unlink.return_value = True
+        p = Pidfile('/var/pid')
+        p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.unlink')
+    def test_remove_ENOENT(self, unlink):
+        exc = OSError()
+        exc.errno = errno.ENOENT
+        unlink.side_effect = exc
+        p = Pidfile('/var/pid')
+        p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.unlink')
+    def test_remove_EACCES(self, unlink):
+        exc = OSError()
+        exc.errno = errno.EACCES
+        unlink.side_effect = exc
+        p = Pidfile('/var/pid')
+        p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.unlink')
+    def test_remove_OSError(self, unlink):
+        exc = OSError()
+        exc.errno = errno.EAGAIN
+        unlink.side_effect = exc
+        p = Pidfile('/var/pid')
+        with self.assertRaises(OSError):
+            p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.kill')
+    def test_remove_if_stale_process_alive(self, kill):
+        p = Pidfile('/var/pid')
+        p.read_pid = Mock()
+        p.read_pid.return_value = 1816
+        kill.return_value = 0
+        self.assertFalse(p.remove_if_stale())
+        kill.assert_called_with(1816, 0)
+        p.read_pid.assert_called_with()
+
+        kill.side_effect = OSError()
+        kill.side_effect.errno = errno.ENOENT
+        self.assertFalse(p.remove_if_stale())
+
+    @patch('os.kill')
+    def test_remove_if_stale_process_dead(self, kill):
+        with override_stdouts():
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
             p.read_pid = Mock()
             p.read_pid.return_value = 1816
             p.read_pid.return_value = 1816
-            kill.return_value = 0
-            self.assertFalse(p.remove_if_stale())
+            p.remove = Mock()
+            exc = OSError()
+            exc.errno = errno.ESRCH
+            kill.side_effect = exc
+            self.assertTrue(p.remove_if_stale())
             kill.assert_called_with(1816, 0)
             kill.assert_called_with(1816, 0)
-            p.read_pid.assert_called_with()
-
-            kill.side_effect = OSError()
-            kill.side_effect.errno = errno.ENOENT
-            self.assertFalse(p.remove_if_stale())
-
-        @patch('os.kill')
-        def test_remove_if_stale_process_dead(self, kill):
-            with override_stdouts():
-                p = Pidfile('/var/pid')
-                p.read_pid = Mock()
-                p.read_pid.return_value = 1816
-                p.remove = Mock()
-                exc = OSError()
-                exc.errno = errno.ESRCH
-                kill.side_effect = exc
-                self.assertTrue(p.remove_if_stale())
-                kill.assert_called_with(1816, 0)
-                p.remove.assert_called_with()
-
-        def test_remove_if_stale_broken_pid(self):
-            with override_stdouts():
-                p = Pidfile('/var/pid')
-                p.read_pid = Mock()
-                p.read_pid.side_effect = ValueError()
-                p.remove = Mock()
-
-                self.assertTrue(p.remove_if_stale())
-                p.remove.assert_called_with()
-
-        def test_remove_if_stale_no_pidfile(self):
+            p.remove.assert_called_with()
+
+    def test_remove_if_stale_broken_pid(self):
+        with override_stdouts():
             p = Pidfile('/var/pid')
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
             p.read_pid = Mock()
-            p.read_pid.return_value = None
+            p.read_pid.side_effect = ValueError()
             p.remove = Mock()
             p.remove = Mock()
 
 
             self.assertTrue(p.remove_if_stale())
             self.assertTrue(p.remove_if_stale())
             p.remove.assert_called_with()
             p.remove.assert_called_with()
 
 
-        @patch('os.fsync')
-        @patch('os.getpid')
-        @patch('os.open')
-        @patch('os.fdopen')
-        @patch(open_fqdn)
-        def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
-            getpid.return_value = 1816
-            osopen.return_value = 13
-            w = fdopen.return_value = WhateverIO()
-            w.close = Mock()
-            r = open_.return_value = WhateverIO()
-            r.write('1816\n')
-            r.seek(0)
-
-            p = Pidfile('/var/pid')
+    def test_remove_if_stale_no_pidfile(self):
+        p = Pidfile('/var/pid')
+        p.read_pid = Mock()
+        p.read_pid.return_value = None
+        p.remove = Mock()
+
+        self.assertTrue(p.remove_if_stale())
+        p.remove.assert_called_with()
+
+    @patch('os.fsync')
+    @patch('os.getpid')
+    @patch('os.open')
+    @patch('os.fdopen')
+    @patch(open_fqdn)
+    def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
+        getpid.return_value = 1816
+        osopen.return_value = 13
+        w = fdopen.return_value = WhateverIO()
+        w.close = Mock()
+        r = open_.return_value = WhateverIO()
+        r.write('1816\n')
+        r.seek(0)
+
+        p = Pidfile('/var/pid')
+        p.write_pid()
+        w.seek(0)
+        self.assertEqual(w.readline(), '1816\n')
+        self.assertTrue(w.close.called)
+        getpid.assert_called_with()
+        osopen.assert_called_with(
+            p.path, platforms.PIDFILE_FLAGS, platforms.PIDFILE_MODE,
+        )
+        fdopen.assert_called_with(13, 'w')
+        fsync.assert_called_with(13)
+        open_.assert_called_with(p.path)
+
+    @patch('os.fsync')
+    @patch('os.getpid')
+    @patch('os.open')
+    @patch('os.fdopen')
+    @patch(open_fqdn)
+    def test_write_reread_fails(self, open_, fdopen,
+                                osopen, getpid, fsync):
+        getpid.return_value = 1816
+        osopen.return_value = 13
+        w = fdopen.return_value = WhateverIO()
+        w.close = Mock()
+        r = open_.return_value = WhateverIO()
+        r.write('11816\n')
+        r.seek(0)
+
+        p = Pidfile('/var/pid')
+        with self.assertRaises(LockFailed):
             p.write_pid()
             p.write_pid()
-            w.seek(0)
-            self.assertEqual(w.readline(), '1816\n')
-            self.assertTrue(w.close.called)
-            getpid.assert_called_with()
-            osopen.assert_called_with(p.path, platforms.PIDFILE_FLAGS,
-                                      platforms.PIDFILE_MODE)
-            fdopen.assert_called_with(13, 'w')
-            fsync.assert_called_with(13)
-            open_.assert_called_with(p.path)
-
-        @patch('os.fsync')
-        @patch('os.getpid')
-        @patch('os.open')
-        @patch('os.fdopen')
-        @patch(open_fqdn)
-        def test_write_reread_fails(self, open_, fdopen,
-                                    osopen, getpid, fsync):
-            getpid.return_value = 1816
-            osopen.return_value = 13
-            w = fdopen.return_value = WhateverIO()
-            w.close = Mock()
-            r = open_.return_value = WhateverIO()
-            r.write('11816\n')
-            r.seek(0)
 
 
-            p = Pidfile('/var/pid')
-            with self.assertRaises(LockFailed):
-                p.write_pid()
 
 
-    class test_setgroups(Case):
+class test_setgroups(Case):
 
 
-        @patch('os.setgroups', create=True)
-        def test_setgroups_hack_ValueError(self, setgroups):
+    @patch('os.setgroups', create=True)
+    def test_setgroups_hack_ValueError(self, setgroups):
 
 
-            def on_setgroups(groups):
-                if len(groups) <= 200:
-                    setgroups.return_value = True
-                    return
-                raise ValueError()
-            setgroups.side_effect = on_setgroups
+        def on_setgroups(groups):
+            if len(groups) <= 200:
+                setgroups.return_value = True
+                return
+            raise ValueError()
+        setgroups.side_effect = on_setgroups
+        _setgroups_hack(list(range(400)))
+
+        setgroups.side_effect = ValueError()
+        with self.assertRaises(ValueError):
             _setgroups_hack(list(range(400)))
             _setgroups_hack(list(range(400)))
 
 
-            setgroups.side_effect = ValueError()
-            with self.assertRaises(ValueError):
-                _setgroups_hack(list(range(400)))
+    @patch('os.setgroups', create=True)
+    def test_setgroups_hack_OSError(self, setgroups):
+        exc = OSError()
+        exc.errno = errno.EINVAL
 
 
-        @patch('os.setgroups', create=True)
-        def test_setgroups_hack_OSError(self, setgroups):
-            exc = OSError()
-            exc.errno = errno.EINVAL
+        def on_setgroups(groups):
+            if len(groups) <= 200:
+                setgroups.return_value = True
+                return
+            raise exc
+        setgroups.side_effect = on_setgroups
 
 
-            def on_setgroups(groups):
-                if len(groups) <= 200:
-                    setgroups.return_value = True
-                    return
-                raise exc
-            setgroups.side_effect = on_setgroups
+        _setgroups_hack(list(range(400)))
 
 
+        setgroups.side_effect = exc
+        with self.assertRaises(OSError):
             _setgroups_hack(list(range(400)))
             _setgroups_hack(list(range(400)))
 
 
-            setgroups.side_effect = exc
-            with self.assertRaises(OSError):
-                _setgroups_hack(list(range(400)))
+        exc2 = OSError()
+        exc.errno = errno.ESRCH
+        setgroups.side_effect = exc2
+        with self.assertRaises(OSError):
+            _setgroups_hack(list(range(400)))
 
 
-            exc2 = OSError()
-            exc.errno = errno.ESRCH
-            setgroups.side_effect = exc2
-            with self.assertRaises(OSError):
-                _setgroups_hack(list(range(400)))
-
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups(self, hack, sysconf):
-            sysconf.return_value = 100
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups(self, hack, sysconf):
+        sysconf.return_value = 100
+        setgroups(list(range(400)))
+        hack.assert_called_with(list(range(100)))
+
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups_sysconf_raises(self, hack, sysconf):
+        sysconf.side_effect = ValueError()
+        setgroups(list(range(400)))
+        hack.assert_called_with(list(range(400)))
+
+    @patch('os.getgroups')
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups_raises_ESRCH(self, hack, sysconf, getgroups):
+        sysconf.side_effect = ValueError()
+        esrch = OSError()
+        esrch.errno = errno.ESRCH
+        hack.side_effect = esrch
+        with self.assertRaises(OSError):
             setgroups(list(range(400)))
             setgroups(list(range(400)))
-            hack.assert_called_with(list(range(100)))
 
 
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups_sysconf_raises(self, hack, sysconf):
-            sysconf.side_effect = ValueError()
-            setgroups(list(range(400)))
-            hack.assert_called_with(list(range(400)))
-
-        @patch('os.getgroups')
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups_raises_ESRCH(self, hack, sysconf, getgroups):
-            sysconf.side_effect = ValueError()
-            esrch = OSError()
-            esrch.errno = errno.ESRCH
-            hack.side_effect = esrch
-            with self.assertRaises(OSError):
-                setgroups(list(range(400)))
-
-        @patch('os.getgroups')
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups_raises_EPERM(self, hack, sysconf, getgroups):
-            sysconf.side_effect = ValueError()
-            eperm = OSError()
-            eperm.errno = errno.EPERM
-            hack.side_effect = eperm
-            getgroups.return_value = list(range(400))
+    @patch('os.getgroups')
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups_raises_EPERM(self, hack, sysconf, getgroups):
+        sysconf.side_effect = ValueError()
+        eperm = OSError()
+        eperm.errno = errno.EPERM
+        hack.side_effect = eperm
+        getgroups.return_value = list(range(400))
+        setgroups(list(range(400)))
+        getgroups.assert_called_with()
+
+        getgroups.return_value = [1000]
+        with self.assertRaises(OSError):
             setgroups(list(range(400)))
             setgroups(list(range(400)))
-            getgroups.assert_called_with()
-
-            getgroups.return_value = [1000]
-            with self.assertRaises(OSError):
-                setgroups(list(range(400)))
-            getgroups.assert_called_with()
+        getgroups.assert_called_with()
 
 
 
 
 class test_check_privileges(Case):
 class test_check_privileges(Case):

+ 3 - 9
celery/tests/utils/test_sysinfo.py

@@ -1,17 +1,14 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-import os
-
 from celery.utils.sysinfo import load_average, df
 from celery.utils.sysinfo import load_average, df
 
 
-from celery.tests.case import Case, SkipTest, patch
+from celery.tests.case import Case, patch, skip_unless_symbol
 
 
 
 
+@skip_unless_symbol('os.getloadavg')
 class test_load_average(Case):
 class test_load_average(Case):
 
 
     def test_avg(self):
     def test_avg(self):
-        if not hasattr(os, 'getloadavg'):
-            raise SkipTest('getloadavg not available')
         with patch('os.getloadavg') as getloadavg:
         with patch('os.getloadavg') as getloadavg:
             getloadavg.return_value = 0.54736328125, 0.6357421875, 0.69921875
             getloadavg.return_value = 0.54736328125, 0.6357421875, 0.69921875
             l = load_average()
             l = load_average()
@@ -19,13 +16,10 @@ class test_load_average(Case):
             self.assertEqual(l, (0.55, 0.64, 0.7))
             self.assertEqual(l, (0.55, 0.64, 0.7))
 
 
 
 
+@skip_unless_symbol('posix.statvfs_result')
 class test_df(Case):
 class test_df(Case):
 
 
     def test_df(self):
     def test_df(self):
-        try:
-            from posix import statvfs_result  # noqa
-        except ImportError:
-            raise SkipTest('statvfs not available')
         x = df('/')
         x = df('/')
         self.assertTrue(x.total_blocks)
         self.assertTrue(x.total_blocks)
         self.assertTrue(x.available)
         self.assertTrue(x.available)

+ 2 - 4
celery/tests/utils/test_term.py

@@ -7,15 +7,13 @@ from celery.utils import term
 from celery.utils.term import colored, fg
 from celery.utils.term import colored, fg
 from celery.five import text_t
 from celery.five import text_t
 
 
-from celery.tests.case import Case, SkipTest
+from celery.tests.case import Case, skip_if_win32
 
 
 
 
+@skip_if_win32()
 class test_colored(Case):
 class test_colored(Case):
 
 
     def setUp(self):
     def setUp(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Colors not supported on Windows')
-
         self._prev_encoding = sys.getdefaultencoding
         self._prev_encoding = sys.getdefaultencoding
 
 
         def getdefaultencoding():
         def getdefaultencoding():

+ 2 - 4
celery/tests/worker/test_components.py

@@ -5,10 +5,9 @@ from __future__ import absolute_import, unicode_literals
 # point [-ask]
 # point [-ask]
 
 
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.platforms import IS_WINDOWS
 from celery.worker.components import Beat, Hub, Pool, Timer
 from celery.worker.components import Beat, Hub, Pool, Timer
 
 
-from celery.tests.case import AppCase, Mock, SkipTest, patch
+from celery.tests.case import AppCase, Mock, patch, skip_if_win32
 
 
 
 
 class test_Timer(AppCase):
 class test_Timer(AppCase):
@@ -61,9 +60,8 @@ class test_Pool(AppCase):
         comp.close(w)
         comp.close(w)
         comp.terminate(w)
         comp.terminate(w)
 
 
+    @skip_if_win32()
     def test_create_when_eventloop(self):
     def test_create_when_eventloop(self):
-        if IS_WINDOWS:
-            raise SkipTest('Win32')
         w = Mock()
         w = Mock()
         w.use_eventloop = w.pool_putlocks = w.pool_cls.uses_semaphore = True
         w.use_eventloop = w.pool_putlocks = w.pool_cls.uses_semaphore = True
         comp = Pool(w)
         comp = Pool(w)

+ 5 - 6
celery/tests/worker/test_consumer.py

@@ -13,7 +13,9 @@ from celery.worker.consumer.heart import Heart
 from celery.worker.consumer.mingle import Mingle
 from celery.worker.consumer.mingle import Mingle
 from celery.worker.consumer.tasks import Tasks
 from celery.worker.consumer.tasks import Tasks
 
 
-from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch
+from celery.tests.case import (
+    AppCase, ContextMock, Mock, call, patch, skip_if_python3,
+)
 
 
 
 
 class test_Consumer(AppCase):
 class test_Consumer(AppCase):
@@ -43,14 +45,11 @@ class test_Consumer(AppCase):
         c = self.get_consumer()
         c = self.get_consumer()
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
 
 
+    @skip_if_python3(reason='buffer type not available')
     def test_dump_body_buffer(self):
     def test_dump_body_buffer(self):
         msg = Mock()
         msg = Mock()
         msg.body = 'str'
         msg.body = 'str'
-        try:
-            buf = buffer(msg.body)
-        except NameError:
-            raise SkipTest('buffer type not available')
-        self.assertTrue(dump_body(msg, buf))
+        self.assertTrue(dump_body(msg, buffer(msg.body)))
 
 
     def test_sets_heartbeat(self):
     def test_sets_heartbeat(self):
         c = self.get_consumer(amqheartbeat=10)
         c = self.get_consumer(amqheartbeat=10)

+ 7 - 11
celery/tests/worker/test_request.py

@@ -45,11 +45,10 @@ from celery.tests.case import (
     AppCase,
     AppCase,
     Case,
     Case,
     Mock,
     Mock,
-    SkipTest,
     TaskMessage,
     TaskMessage,
-    assert_signal_called,
     task_message_from_sig,
     task_message_from_sig,
     patch,
     patch,
+    skip_if_python3,
 )
 )
 
 
 
 
@@ -124,12 +123,9 @@ def jail(app, task_id, name, args, kwargs):
     ).retval
     ).retval
 
 
 
 
+@skip_if_python3
 class test_default_encode(AppCase):
 class test_default_encode(AppCase):
 
 
-    def setup(self):
-        if sys.version_info >= (3, 0):
-            raise SkipTest('py3k: not relevant')
-
     def test_jython(self):
     def test_jython(self):
         prev, sys.platform = sys.platform, 'java 1.6.1'
         prev, sys.platform = sys.platform, 'java 1.6.1'
         try:
         try:
@@ -430,7 +426,7 @@ class test_Request(RequestCase):
         signum = signal.SIGTERM
         signum = signal.SIGTERM
         job = self.get_request(self.mytask.s(1, f='x'))
         job = self.get_request(self.mytask.s(1, f='x'))
         job._apply_result = Mock(name='_apply_result')
         job._apply_result = Mock(name='_apply_result')
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
                 terminated=True, expired=False, signum=signum):
             job.time_start = monotonic()
             job.time_start = monotonic()
@@ -446,7 +442,7 @@ class test_Request(RequestCase):
         pool = Mock()
         pool = Mock()
         signum = signal.SIGTERM
         signum = signal.SIGTERM
         job = self.get_request(self.mytask.s(1, f='x'))
         job = self.get_request(self.mytask.s(1, f='x'))
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
                 terminated=True, expired=False, signum=signum):
             job.time_start = monotonic()
             job.time_start = monotonic()
@@ -467,7 +463,7 @@ class test_Request(RequestCase):
         job = self.get_request(self.mytask.s(1, f='x').set(
         job = self.get_request(self.mytask.s(1, f='x').set(
             expires=datetime.utcnow() - timedelta(days=1)
             expires=datetime.utcnow() - timedelta(days=1)
         ))
         ))
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 task_revoked, sender=job.task, request=job,
                 terminated=False, expired=True, signum=None):
                 terminated=False, expired=True, signum=None):
             job.revoked()
             job.revoked()
@@ -506,7 +502,7 @@ class test_Request(RequestCase):
 
 
     def test_revoked(self):
     def test_revoked(self):
         job = self.xRequest()
         job = self.xRequest()
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 task_revoked, sender=job.task, request=job,
                 terminated=False, expired=False, signum=None):
                 terminated=False, expired=False, signum=None):
             revoked.add(job.id)
             revoked.add(job.id)
@@ -555,7 +551,7 @@ class test_Request(RequestCase):
         signum = signal.SIGTERM
         signum = signal.SIGTERM
         pool = Mock()
         pool = Mock()
         job = self.xRequest()
         job = self.xRequest()
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
                 terminated=True, expired=False, signum=signum):
             job.terminate(pool, signal='TERM')
             job.terminate(pool, signal='TERM')

+ 2 - 2
celery/tests/worker/test_worker.py

@@ -30,7 +30,7 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
-from celery.tests.case import AppCase, Mock, SkipTest, TaskMessage, patch
+from celery.tests.case import AppCase, Mock, TaskMessage, patch, todo
 
 
 
 
 def MockStep(step=None):
 def MockStep(step=None):
@@ -849,8 +849,8 @@ class test_WorkController(AppCase):
             self.worker._send_worker_shutdown()
             self.worker._send_worker_shutdown()
             ws.send.assert_called_with(sender=self.worker)
             ws.send.assert_called_with(sender=self.worker)
 
 
+    @todo('unstable test')
     def test_process_shutdown_on_worker_shutdown(self):
     def test_process_shutdown_on_worker_shutdown(self):
-        raise SkipTest('unstable test')
         from celery.concurrency.prefork import process_destructor
         from celery.concurrency.prefork import process_destructor
         from celery.concurrency.asynpool import Worker
         from celery.concurrency.asynpool import Worker
         with patch('celery.signals.worker_process_shutdown') as ws:
         with patch('celery.signals.worker_process_shutdown') as ws: