Selaa lähdekoodia

[tests] Cleanup SkipTest stuff

Ask Solem 9 vuotta sitten
vanhempi
commit
1a572fb55c
40 muutettua tiedostoa jossa 1549 lisäystä ja 1540 poistoa
  1. 799 0
      celery/tests/_case.py
  2. 2 2
      celery/tests/app/test_app.py
  3. 2 6
      celery/tests/app/test_beat.py
  4. 2 2
      celery/tests/app/test_loaders.py
  5. 6 6
      celery/tests/app/test_log.py
  6. 3 6
      celery/tests/app/test_schedules.py
  7. 4 3
      celery/tests/backends/test_base.py
  8. 2 2
      celery/tests/backends/test_cache.py
  9. 2 52
      celery/tests/backends/test_couchbase.py
  10. 2 3
      celery/tests/backends/test_couchdb.py
  11. 6 13
      celery/tests/backends/test_database.py
  12. 2 8
      celery/tests/backends/test_elasticsearch.py
  13. 2 4
      celery/tests/backends/test_filesystem.py
  14. 7 13
      celery/tests/backends/test_mongodb.py
  15. 6 8
      celery/tests/backends/test_redis.py
  16. 3 4
      celery/tests/backends/test_riak.py
  17. 2 1
      celery/tests/bin/test_amqp.py
  18. 3 2
      celery/tests/bin/test_celery.py
  19. 2 1
      celery/tests/bin/test_celeryevdump.py
  20. 2 6
      celery/tests/bin/test_events.py
  21. 3 3
      celery/tests/bin/test_multi.py
  22. 35 61
      celery/tests/bin/test_worker.py
  23. 33 682
      celery/tests/case.py
  24. 1 2
      celery/tests/concurrency/test_eventlet.py
  25. 1 1
      celery/tests/concurrency/test_gevent.py
  26. 2 5
      celery/tests/concurrency/test_pool.py
  27. 7 16
      celery/tests/concurrency/test_prefork.py
  28. 4 3
      celery/tests/contrib/test_rdb.py
  29. 2 6
      celery/tests/events/test_cursesmon.py
  30. 2 2
      celery/tests/events/test_state.py
  31. 3 7
      celery/tests/security/case.py
  32. 2 2
      celery/tests/security/test_certificate.py
  33. 3 11
      celery/tests/utils/test_datastructures.py
  34. 571 561
      celery/tests/utils/test_platforms.py
  35. 3 9
      celery/tests/utils/test_sysinfo.py
  36. 2 4
      celery/tests/utils/test_term.py
  37. 2 4
      celery/tests/worker/test_components.py
  38. 5 6
      celery/tests/worker/test_consumer.py
  39. 7 11
      celery/tests/worker/test_request.py
  40. 2 2
      celery/tests/worker/test_worker.py

+ 799 - 0
celery/tests/_case.py

@@ -0,0 +1,799 @@
+from __future__ import absolute_import, unicode_literals
+
+import importlib
+import inspect
+import io
+import logging
+import os
+import platform
+import re
+import sys
+import time
+import types
+import warnings
+
+from contextlib import contextmanager
+from functools import partial, wraps
+from six import (
+    iteritems as items,
+    itervalues as values,
+    string_types,
+    reraise,
+)
+from six.moves import builtins
+
+from nose import SkipTest
+
+try:
+    import unittest  # noqa
+    unittest.skip
+    from unittest.util import safe_repr, unorderable_list_difference
+except AttributeError:
+    import unittest2 as unittest  # noqa
+    from unittest2.util import safe_repr, unorderable_list_difference  # noqa
+
+try:
+    from unittest import mock
+except ImportError:
+    import mock  # noqa
+
+__all__ = [
+    'ANY', 'Case', 'ContextMock', 'MagicMock', 'Mock', 'MockCallbacks',
+    'call', 'patch', 'sentinel',
+
+    'mock_open', 'mock_context', 'mock_module',
+    'patch_modules', 'reset_modules', 'sys_platform', 'pypy_version',
+    'platform_pyimp', 'replace_module_value', 'override_stdouts',
+    'mask_modules', 'sleepdeprived', 'mock_environ', 'wrap_logger',
+    'restore_logging',
+
+    'todo', 'skip', 'skip_if_darwin', 'skip_if_environ',
+    'skip_if_jython', 'skip_if_platform', 'skip_if_pypy', 'skip_if_python3',
+    'skip_if_win32', 'skip_unless_module', 'skip_unless_symbol',
+]
+
+patch = mock.patch
+call = mock.call
+sentinel = mock.sentinel
+MagicMock = mock.MagicMock
+ANY = mock.ANY
+
+PY3 = sys.version_info[0] == 3
+if PY3:
+    open_fqdn = 'builtins.open'
+    module_name_t = str
+else:
+    open_fqdn = '__builtin__.open'  # noqa
+    module_name_t = bytes  # noqa
+
+StringIO = io.StringIO
+_SIO_write = StringIO.write
+_SIO_init = StringIO.__init__
+
+
+def symbol_by_name(name, aliases={}, imp=None, package=None,
+                   sep='.', default=None, **kwargs):
+    """Get symbol by qualified name.
+
+    The name should be the full dot-separated path to the class::
+
+        modulename.ClassName
+
+    Example::
+
+        celery.concurrency.processes.TaskPool
+                                    ^- class name
+
+    or using ':' to separate module and symbol::
+
+        celery.concurrency.processes:TaskPool
+
+    If `aliases` is provided, a dict containing short name/long name
+    mappings, the name is looked up in the aliases first.
+
+    Examples:
+
+        >>> symbol_by_name('celery.concurrency.processes.TaskPool')
+        <class 'celery.concurrency.processes.TaskPool'>
+
+        >>> symbol_by_name('default', {
+        ...     'default': 'celery.concurrency.processes.TaskPool'})
+        <class 'celery.concurrency.processes.TaskPool'>
+
+        # Does not try to look up non-string names.
+        >>> from celery.concurrency.processes import TaskPool
+        >>> symbol_by_name(TaskPool) is TaskPool
+        True
+
+    """
+    if imp is None:
+        imp = importlib.import_module
+
+    if not isinstance(name, string_types):
+        return name                                 # already a class
+
+    name = aliases.get(name) or name
+    sep = ':' if ':' in name else sep
+    module_name, _, cls_name = name.rpartition(sep)
+    if not module_name:
+        cls_name, module_name = None, package if package else cls_name
+    try:
+        try:
+            module = imp(module_name, package=package, **kwargs)
+        except ValueError as exc:
+            reraise(ValueError,
+                    ValueError("Couldn't import {0!r}: {1}".format(name, exc)),
+                    sys.exc_info()[2])
+        return getattr(module, cls_name) if cls_name else module
+    except (ImportError, AttributeError):
+        if default is None:
+            raise
+    return default
+
+
+class WhateverIO(StringIO):
+
+    def __init__(self, v=None, *a, **kw):
+        _SIO_init(self, v.decode() if isinstance(v, bytes) else v, *a, **kw)
+
+    def write(self, data):
+        _SIO_write(self, data.decode() if isinstance(data, bytes) else data)
+
+
+def noop(*args, **kwargs):
+    pass
+
+
+class Mock(mock.Mock):
+
+    def __init__(self, *args, **kwargs):
+        attrs = kwargs.pop('attrs', None) or {}
+        super(Mock, self).__init__(*args, **kwargs)
+        for attr_name, attr_value in items(attrs):
+            setattr(self, attr_name, attr_value)
+
+
+class _ContextMock(Mock):
+    """Dummy class implementing __enter__ and __exit__
+    as the :keyword:`with` statement requires these to be implemented
+    in the class, not just the instance."""
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        pass
+
+
+def ContextMock(*args, **kwargs):
+    obj = _ContextMock(*args, **kwargs)
+    obj.attach_mock(_ContextMock(), '__enter__')
+    obj.attach_mock(_ContextMock(), '__exit__')
+    obj.__enter__.return_value = obj
+    # if __exit__ return a value the exception is ignored,
+    # so it must return None here.
+    obj.__exit__.return_value = None
+    return obj
+
+
+def _bind(f, o):
+    @wraps(f)
+    def bound_meth(*fargs, **fkwargs):
+        return f(o, *fargs, **fkwargs)
+    return bound_meth
+
+
+if PY3:  # pragma: no cover
+    def _get_class_fun(meth):
+        return meth
+else:
+    def _get_class_fun(meth):
+        return meth.__func__
+
+
+class MockCallbacks(object):
+
+    def __new__(cls, *args, **kwargs):
+        r = Mock(name=cls.__name__)
+        _get_class_fun(cls.__init__)(r, *args, **kwargs)
+        for key, value in items(vars(cls)):
+            if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
+                if inspect.ismethod(value) or inspect.isfunction(value):
+                    r.__getattr__(key).side_effect = _bind(value, r)
+                else:
+                    r.__setattr__(key, value)
+        return r
+
+
+# -- adds assertWarns from recent unittest2, not in Python 2.7.
+
+class _AssertRaisesBaseContext(object):
+
+    def __init__(self, expected, test_case, callable_obj=None,
+                 expected_regex=None):
+        self.expected = expected
+        self.failureException = test_case.failureException
+        self.obj_name = None
+        if isinstance(expected_regex, string_types):
+            expected_regex = re.compile(expected_regex)
+        self.expected_regex = expected_regex
+
+
+def _is_magic_module(m):
+    # some libraries create custom module types that are lazily
+    # lodaded, e.g. Django installs some modules in sys.modules that
+    # will load _tkinter and other shit when touched.
+
+    # pyflakes refuses to accept 'noqa' for this isinstance.
+    cls, modtype = type(m), types.ModuleType
+    try:
+        variables = vars(cls)
+    except TypeError:
+        return True
+    else:
+        return (cls is not modtype and (
+            '__getattr__' in variables or
+            '__getattribute__' in variables))
+
+
+class _AssertWarnsContext(_AssertRaisesBaseContext):
+    """A context manager used to implement TestCase.assertWarns* methods."""
+
+    def __enter__(self):
+        # The __warningregistry__'s need to be in a pristine state for tests
+        # to work properly.
+        warnings.resetwarnings()
+        for v in list(values(sys.modules)):
+            # do not evaluate Django moved modules and other lazily
+            # initialized modules.
+            if v and not _is_magic_module(v):
+                # use raw __getattribute__ to protect even better from
+                # lazily loaded modules
+                try:
+                    object.__getattribute__(v, '__warningregistry__')
+                except AttributeError:
+                    pass
+                else:
+                    object.__setattr__(v, '__warningregistry__', {})
+        self.warnings_manager = warnings.catch_warnings(record=True)
+        self.warnings = self.warnings_manager.__enter__()
+        warnings.simplefilter('always', self.expected)
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        self.warnings_manager.__exit__(exc_type, exc_value, tb)
+        if exc_type is not None:
+            # let unexpected exceptions pass through
+            return
+        try:
+            exc_name = self.expected.__name__
+        except AttributeError:
+            exc_name = str(self.expected)
+        first_matching = None
+        for m in self.warnings:
+            w = m.message
+            if not isinstance(w, self.expected):
+                continue
+            if first_matching is None:
+                first_matching = w
+            if (self.expected_regex is not None and
+                    not self.expected_regex.search(str(w))):
+                continue
+            # store warning for later retrieval
+            self.warning = w
+            self.filename = m.filename
+            self.lineno = m.lineno
+            return
+        # Now we simply try to choose a helpful failure message
+        if first_matching is not None:
+            raise self.failureException(
+                '%r does not match %r' % (
+                    self.expected_regex.pattern, str(first_matching)))
+        if self.obj_name:
+            raise self.failureException(
+                '%s not triggered by %s' % (exc_name, self.obj_name))
+        else:
+            raise self.failureException('%s not triggered' % exc_name)
+
+
+class Case(unittest.TestCase):
+    DeprecationWarning = DeprecationWarning
+    PendingDeprecationWarning = PendingDeprecationWarning
+
+    def patch(self, *path, **options):
+        manager = patch('.'.join(path), **options)
+        patched = manager.start()
+        self.addCleanup(manager.stop)
+        return patched
+
+    def mock_modules(self, *mods):
+        modules = []
+        for mod in mods:
+            mod = mod.split('.')
+            modules.extend(reversed([
+                '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
+            ]))
+        modules = sorted(set(modules))
+        return self.wrap_context(mock_module(*modules))
+
+    def on_nth_call_do(self, mock, side_effect, n=1):
+
+        def on_call(*args, **kwargs):
+            if mock.call_count >= n:
+                mock.side_effect = side_effect
+            return mock.return_value
+        mock.side_effect = on_call
+        return mock
+
+    def on_nth_call_return(self, mock, retval, n=1):
+
+        def on_call(*args, **kwargs):
+            if mock.call_count >= n:
+                mock.return_value = retval
+            return mock.return_value
+        mock.side_effect = on_call
+        return mock
+
+    def mask_modules(self, *modules):
+        self.wrap_context(mask_modules(*modules))
+
+    def wrap_context(self, context):
+        ret = context.__enter__()
+        self.addCleanup(partial(context.__exit__, None, None, None))
+        return ret
+
+    def mock_environ(self, env_name, env_value):
+        return self.wrap_context(mock_environ(env_name, env_value))
+
+    def assertWarns(self, expected_warning):
+        return _AssertWarnsContext(expected_warning, self, None)
+
+    def assertWarnsRegex(self, expected_warning, expected_regex):
+        return _AssertWarnsContext(expected_warning, self,
+                                   None, expected_regex)
+
+    @contextmanager
+    def assertDeprecated(self):
+        with self.assertWarnsRegex(self.DeprecationWarning,
+                                   r'scheduled for removal'):
+            yield
+
+    @contextmanager
+    def assertPendingDeprecation(self):
+        with self.assertWarnsRegex(self.PendingDeprecationWarning,
+                                   r'scheduled for deprecation'):
+            yield
+
+    def assertDictContainsSubset(self, expected, actual, msg=None):
+        missing, mismatched = [], []
+
+        for key, value in items(expected):
+            if key not in actual:
+                missing.append(key)
+            elif value != actual[key]:
+                mismatched.append('%s, expected: %s, actual: %s' % (
+                    safe_repr(key), safe_repr(value),
+                    safe_repr(actual[key])))
+
+        if not (missing or mismatched):
+            return
+
+        standard_msg = ''
+        if missing:
+            standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
+
+        if mismatched:
+            if standard_msg:
+                standard_msg += '; '
+            standard_msg += 'Mismatched values: %s' % (
+                ','.join(mismatched))
+
+        self.fail(self._formatMessage(msg, standard_msg))
+
+    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
+        missing = unexpected = None
+        try:
+            expected = sorted(expected_seq)
+            actual = sorted(actual_seq)
+        except TypeError:
+            # Unsortable items (example: set(), complex(), ...)
+            expected = list(expected_seq)
+            actual = list(actual_seq)
+            missing, unexpected = unorderable_list_difference(
+                expected, actual)
+        else:
+            return self.assertSequenceEqual(expected, actual, msg=msg)
+
+        errors = []
+        if missing:
+            errors.append(
+                'Expected, but missing:\n    %s' % (safe_repr(missing),)
+            )
+        if unexpected:
+            errors.append(
+                'Unexpected, but present:\n    %s' % (safe_repr(unexpected),)
+            )
+        if errors:
+            standardMsg = '\n'.join(errors)
+            self.fail(self._formatMessage(msg, standardMsg))
+
+
+class _CallableContext(object):
+
+    def __init__(self, context, cargs, ckwargs, fun):
+        self.context = context
+        self.cargs = cargs
+        self.ckwargs = ckwargs
+        self.fun = fun
+
+    def __call__(self, *args, **kwargs):
+        return self.fun(*args, **kwargs)
+
+    def __enter__(self):
+        self.ctx = self.context(*self.cargs, **self.ckwargs)
+        return self.ctx.__enter__()
+
+    def __exit__(self, *einfo):
+        if self.ctx:
+            return self.ctx.__exit__(*einfo)
+
+
+def decorator(predicate):
+
+    @wraps(predicate)
+    def take_arguments(*pargs, **pkwargs):
+
+        @wraps(predicate)
+        def decorator(cls):
+            if inspect.isclass(cls):
+                orig_setup = cls.setUp
+                orig_teardown = cls.tearDown
+
+                @wraps(cls.setUp)
+                def around_setup(*args, **kwargs):
+                    try:
+                        contexts = args[0].__rb3dc_contexts__
+                    except AttributeError:
+                        contexts = args[0].__rb3dc_contexts__ = []
+                    p = predicate(*pargs, **pkwargs)
+                    p.__enter__()
+                    contexts.append(p)
+                    return orig_setup(*args, **kwargs)
+                around_setup.__wrapped__ = cls.setUp
+                cls.setUp = around_setup
+
+                @wraps(cls.tearDown)
+                def around_teardown(*args, **kwargs):
+                    try:
+                        contexts = args[0].__rb3dc_contexts__
+                    except AttributeError:
+                        pass
+                    else:
+                        for context in contexts:
+                            context.__exit__(*sys.exc_info())
+                    orig_teardown(*args, **kwargs)
+                around_teardown.__wrapped__ = cls.tearDown
+                cls.tearDown = around_teardown
+
+                return cls
+            else:
+                @wraps(cls)
+                def around_case(*args, **kwargs):
+                    with predicate(*pargs, **pkwargs):
+                        return cls(*args, **kwargs)
+                return around_case
+
+        if len(pargs) == 1 and callable(pargs[0]):
+            fun, pargs = pargs[0], ()
+            return decorator(fun)
+        return _CallableContext(predicate, pargs, pkwargs, decorator)
+    return take_arguments
+
+
+@decorator
+@contextmanager
+def skip_unless_module(module, name=None):
+    try:
+        importlib.import_module(module)
+    except (ImportError, OSError):
+        raise SkipTest('module not installed: {0}'.format(name or module))
+    yield
+
+
+@decorator
+@contextmanager
+def skip_unless_symbol(symbol, name=None):
+    try:
+        symbol_by_name(symbol)
+    except (AttributeError, ImportError):
+        raise SkipTest('missing symbol {0}'.format(name or symbol))
+    yield
+
+
+def get_logger_handlers(logger):
+    return [
+        h for h in logger.handlers
+        if not isinstance(h, logging.NullHandler)
+    ]
+
+
+@decorator
+@contextmanager
+def wrap_logger(logger, loglevel=logging.ERROR):
+    old_handlers = get_logger_handlers(logger)
+    sio = WhateverIO()
+    siohandler = logging.StreamHandler(sio)
+    logger.handlers = [siohandler]
+
+    try:
+        yield sio
+    finally:
+        logger.handlers = old_handlers
+
+
+@decorator
+@contextmanager
+def mock_environ(env_name, env_value):
+    sentinel = object()
+    prev_val = os.environ.get(env_name, sentinel)
+    os.environ[env_name] = env_value
+    try:
+        yield env_value
+    finally:
+        if prev_val is sentinel:
+            os.environ.pop(env_name, None)
+        else:
+            os.environ[env_name] = prev_val
+
+
+@decorator
+@contextmanager
+def sleepdeprived(module=time):
+    old_sleep, module.sleep = module.sleep, noop
+    try:
+        yield
+    finally:
+        module.sleep = old_sleep
+
+
+@decorator
+@contextmanager
+def skip_if_python3(reason='incompatible'):
+    if PY3:
+        raise SkipTest('Python3: {0}'.format(reason))
+    yield
+
+
+@decorator
+@contextmanager
+def skip_if_environ(env_var_name):
+    if os.environ.get(env_var_name):
+        raise SkipTest('envvar {0} set'.format(env_var_name))
+    yield
+
+
+@decorator
+@contextmanager
+def _skip_test(reason, sign):
+    raise SkipTest('{0}: {1}'.format(sign, reason))
+    yield
+todo = partial(_skip_test, sign='TODO')
+skip = partial(_skip_test, sign='SKIP')
+
+
+# Taken from
+# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
+@decorator
+@contextmanager
+def mask_modules(*modnames):
+    """Ban some modules from being importable inside the context
+
+    For example:
+
+        >>> with mask_modules('sys'):
+        ...     try:
+        ...         import sys
+        ...     except ImportError:
+        ...         print('sys not found')
+        sys not found
+
+        >>> import sys  # noqa
+        >>> sys.version
+        (2, 5, 2, 'final', 0)
+
+    """
+    realimport = builtins.__import__
+
+    def myimp(name, *args, **kwargs):
+        if name in modnames:
+            raise ImportError('No module named %s' % name)
+        else:
+            return realimport(name, *args, **kwargs)
+
+    builtins.__import__ = myimp
+    try:
+        yield True
+    finally:
+        builtins.__import__ = realimport
+
+
+@decorator
+@contextmanager
+def override_stdouts():
+    """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
+    prev_out, prev_err = sys.stdout, sys.stderr
+    prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
+    mystdout, mystderr = WhateverIO(), WhateverIO()
+    sys.stdout = sys.__stdout__ = mystdout
+    sys.stderr = sys.__stderr__ = mystderr
+
+    try:
+        yield mystdout, mystderr
+    finally:
+        sys.stdout = prev_out
+        sys.stderr = prev_err
+        sys.__stdout__ = prev_rout
+        sys.__stderr__ = prev_rerr
+
+
+@decorator
+@contextmanager
+def replace_module_value(module, name, value=None):
+    has_prev = hasattr(module, name)
+    prev = getattr(module, name, None)
+    if value:
+        setattr(module, name, value)
+    else:
+        try:
+            delattr(module, name)
+        except AttributeError:
+            pass
+    try:
+        yield
+    finally:
+        if prev is not None:
+            setattr(module, name, prev)
+        if not has_prev:
+            try:
+                delattr(module, name)
+            except AttributeError:
+                pass
+pypy_version = partial(
+    replace_module_value, sys, 'pypy_version_info',
+)
+platform_pyimp = partial(
+    replace_module_value, platform, 'python_implementation',
+)
+
+
+@decorator
+@contextmanager
+def sys_platform(value):
+    prev, sys.platform = sys.platform, value
+    try:
+        yield
+    finally:
+        sys.platform = prev
+
+
+@decorator
+@contextmanager
+def reset_modules(*modules):
+    prev = {k: sys.modules.pop(k) for k in modules if k in sys.modules}
+    try:
+        yield
+    finally:
+        sys.modules.update(prev)
+
+
+@decorator
+@contextmanager
+def patch_modules(*modules):
+    prev = {}
+    for mod in modules:
+        prev[mod] = sys.modules.get(mod)
+        sys.modules[mod] = types.ModuleType(module_name_t(mod))
+    try:
+        yield
+    finally:
+        for name, mod in items(prev):
+            if mod is None:
+                sys.modules.pop(name, None)
+            else:
+                sys.modules[name] = mod
+
+
+@decorator
+@contextmanager
+def mock_module(*names):
+    prev = {}
+
+    class MockModule(types.ModuleType):
+
+        def __getattr__(self, attr):
+            setattr(self, attr, Mock())
+            return types.ModuleType.__getattribute__(self, attr)
+
+    mods = []
+    for name in names:
+        try:
+            prev[name] = sys.modules[name]
+        except KeyError:
+            pass
+        mod = sys.modules[name] = MockModule(module_name_t(name))
+        mods.append(mod)
+    try:
+        yield mods
+    finally:
+        for name in names:
+            try:
+                sys.modules[name] = prev[name]
+            except KeyError:
+                try:
+                    del(sys.modules[name])
+                except KeyError:
+                    pass
+
+
+@contextmanager
+def mock_context(mock, typ=Mock):
+    context = mock.return_value = Mock()
+    context.__enter__ = typ()
+    context.__exit__ = typ()
+
+    def on_exit(*x):
+        if x[0]:
+            reraise(x[0], x[1], x[2])
+    context.__exit__.side_effect = on_exit
+    context.__enter__.return_value = context
+    try:
+        yield context
+    finally:
+        context.reset()
+
+
+@decorator
+@contextmanager
+def mock_open(typ=WhateverIO, side_effect=None):
+    with patch(open_fqdn) as open_:
+        with mock_context(open_) as context:
+            if side_effect is not None:
+                context.__enter__.side_effect = side_effect
+            val = context.__enter__.return_value = typ()
+            val.__exit__ = Mock()
+            yield val
+
+
+@decorator
+@contextmanager
+def skip_if_platform(platform_name, name=None):
+    if sys.platform.startswith(platform_name):
+        raise SkipTest('does not work on {0}'.format(platform_name or name))
+    yield
+skip_if_jython = partial(skip_if_platform, 'java', name='Jython')
+skip_if_win32 = partial(skip_if_platform, 'win32', name='Windows')
+skip_if_darwin = partial(skip_if_platform, 'darwin', name='OS X')
+
+
+@decorator
+@contextmanager
+def skip_if_pypy():
+    if getattr(sys, 'pypy_version_info', None):
+        raise SkipTest('does not work on PyPy')
+    yield
+
+
+@decorator
+@contextmanager
+def restore_logging():
+    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
+    root = logging.getLogger()
+    level = root.level
+    handlers = root.handlers
+
+    try:
+        yield
+    finally:
+        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
+        root.level = level
+        root.handlers[:] = handlers

+ 2 - 2
celery/tests/app/test_app.py

@@ -34,7 +34,7 @@ from celery.tests.case import (
     platform_pyimp,
     sys_platform,
     pypy_version,
-    with_environ,
+    mock_environ,
 )
 from celery.utils import uuid
 from celery.utils.mail import ErrorMail
@@ -236,7 +236,7 @@ class test_App(AppCase):
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
         )
 
-    @with_environ('CELERY_BROKER_URL', '')
+    @mock_environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
         with self.Celery(broker='foo://baribaz') as app:
             self.assertEqual(app.conf.broker_url, 'foo://baribaz')

+ 2 - 6
celery/tests/app/test_beat.py

@@ -11,7 +11,7 @@ from celery.schedules import schedule
 from celery.utils import uuid
 from celery.utils.objects import Bunch
 
-from celery.tests.case import AppCase, Mock, SkipTest, call, patch
+from celery.tests.case import AppCase, Mock, call, patch, skip_unless_module
 
 
 class MockShelve(dict):
@@ -485,12 +485,8 @@ class test_Service(AppCase):
 
 class test_EmbeddedService(AppCase):
 
+    @skip_unless_module('_multiprocessing', name='multiprocessing')
     def test_start_stop_process(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('multiprocessing not available')
-
         from billiard.process import Process
 
         s = beat.EmbeddedService(self.app)

+ 2 - 2
celery/tests/app/test_loaders.py

@@ -13,7 +13,7 @@ from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 from celery.utils.mail import SendmailWarning
 
-from celery.tests.case import AppCase, Case, Mock, patch, with_environ
+from celery.tests.case import AppCase, Case, Mock, mock_environ, patch
 
 
 class DummyLoader(base.BaseLoader):
@@ -144,7 +144,7 @@ class test_DefaultLoader(AppCase):
             l.read_configuration(fail_silently=False)
 
     @patch('celery.loaders.base.find_module')
-    @with_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
+    @mock_environ('CELERY_CONFIG_MODULE', 'celeryconfig.py')
     def test_read_configuration_py_in_name(self, find_module):
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)

+ 6 - 6
celery/tests/app/test_log.py

@@ -21,9 +21,10 @@ from celery.utils.log import (
     logger_isa,
 )
 from celery.tests.case import (
-    AppCase, Mock, SkipTest, mask_modules,
-    get_handlers, override_stdouts, patch, wrap_logger, restore_logging,
+    AppCase, Mock, mask_modules, skip_if_python3,
+    override_stdouts, patch, wrap_logger, restore_logging,
 )
+from celery.tests._case import get_logger_handlers
 
 
 class test_TaskFormatter(AppCase):
@@ -155,10 +156,9 @@ class test_ColorFormatter(AppCase):
         self.assertIn('<Unrepresentable', msg)
         self.assertEqual(safe_str.call_count, 1)
 
+    @skip_if_python3()
     @patch('celery.utils.log.safe_str')
     def test_format_raises_no_color(self, safe_str):
-        if sys.version_info[0] == 3:
-            raise SkipTest('py3k')
         x = ColorFormatter(use_color=False)
         record = Mock()
         record.levelname = 'ERROR'
@@ -235,7 +235,7 @@ class test_default_logger(AppCase):
             logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                        root=False, colorize=None)
             self.assertIs(
-                get_handlers(logger)[0].stream, sys.__stderr__,
+                get_logger_handlers(logger)[0].stream, sys.__stderr__,
                 'setup_logger logs to stderr without logfile argument.',
             )
 
@@ -273,7 +273,7 @@ class test_default_logger(AppCase):
                     logfile=tempfile, loglevel=logging.INFO, root=False,
                 )
                 self.assertIsInstance(
-                    get_handlers(l)[0], logging.FileHandler,
+                    get_logger_handlers(l)[0], logging.FileHandler,
                 )
                 self.assertIn(tempfile, files)
 

+ 3 - 6
celery/tests/app/test_schedules.py

@@ -10,7 +10,7 @@ from celery.five import items
 from celery.schedules import (
     ParseException, crontab, crontab_parser, schedule, solar,
 )
-from celery.tests.case import AppCase, Mock, SkipTest
+from celery.tests.case import AppCase, Mock, skip_unless_module, todo
 
 
 @contextmanager
@@ -23,13 +23,10 @@ def patch_crontab_nowfun(cls, retval):
         cls.nowfun = prev_nowfun
 
 
+@skip_unless_module('ephem')
 class test_solar(AppCase):
 
     def setup(self):
-        try:
-            import ephem  # noqa
-        except ImportError:
-            raise SkipTest('ephem module not installed')
         self.s = solar('sunrise', 60, 30, app=self.app)
 
     def test_reduce(self):
@@ -738,8 +735,8 @@ class test_crontab_is_due(AppCase):
             self.assertTrue(due)
             self.assertEqual(remaining, 60.)
 
+    @todo('unstable test')
     def test_monthly_moy_execution_is_not_due(self):
-        raise SkipTest('unstable test')
         with patch_crontab_nowfun(
                 self.monthly_moy, datetime(2013, 6, 28, 14, 30)):
             due, remaining = self.monthly_moy.is_due(

+ 4 - 3
celery/tests/backends/test_base.py

@@ -25,7 +25,9 @@ from celery.result import result_from_tuple
 from celery.utils import uuid
 from celery.utils.functional import pass1
 
-from celery.tests.case import ANY, AppCase, Case, Mock, SkipTest, call, patch
+from celery.tests.case import (
+    ANY, AppCase, Case, Mock, call, patch, skip_if_python3,
+)
 
 
 class wrapobject(object):
@@ -92,9 +94,8 @@ class test_BaseBackend_interface(AppCase):
 
 class test_exception_pickle(AppCase):
 
+    @skip_if_python3('does not support old style classes')
     def test_oldstyle(self):
-        if Oldstyle is None:
-            raise SkipTest('py3k does not support old style classes')
         self.assertTrue(fnpe(Oldstyle()))
 
     def test_BaseException(self):

+ 2 - 2
celery/tests/backends/test_cache.py

@@ -16,7 +16,7 @@ from celery.five import items, module_name_t, string, text_t
 from celery.utils import uuid
 
 from celery.tests.case import (
-    AppCase, Mock, disable_stdouts, mask_modules, patch, reset_modules,
+    AppCase, Mock, override_stdouts, mask_modules, patch, reset_modules,
 )
 
 PY3 = sys.version_info[0] == 3
@@ -136,7 +136,7 @@ class test_CacheBackend(AppCase):
         b = CacheBackend(backend=backend, app=self.app)
         self.assertEqual(b.as_uri(), backend)
 
-    @disable_stdouts
+    @override_stdouts
     def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'

+ 2 - 52
celery/tests/backends/test_couchbase.py

@@ -9,7 +9,7 @@ from celery.backends.couchbase import CouchBaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, SkipTest, patch, sentinel,
+    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
 )
 
 try:
@@ -20,24 +20,13 @@ except ImportError:
 COUCHBASE_BUCKET = 'celery_bucket'
 
 
+@skip_unless_module('couchbase')
 class test_CouchBaseBackend(AppCase):
 
-    """CouchBaseBackend TestCase."""
-
     def setup(self):
-        """Skip the test if couchbase cannot be imported."""
-        if couchbase is None:
-            raise SkipTest('couchbase is not installed.')
         self.backend = CouchBaseBackend(app=self.app)
 
     def test_init_no_couchbase(self):
-        """
-        Test init no couchbase raises.
-
-        If celery.backends.couchbase cannot import the couchbase client, it
-        sets the couchbase.Couchbase to None and then handles this in the
-        CouchBaseBackend __init__ method.
-        """
         prev, module.Couchbase = module.Couchbase, None
         try:
             with self.assertRaises(ImproperlyConfigured):
@@ -46,18 +35,15 @@ class test_CouchBaseBackend(AppCase):
             module.Couchbase = prev
 
     def test_init_no_settings(self):
-        """Test init no settings."""
         self.app.conf.couchbase_backend_settings = []
         with self.assertRaises(ImproperlyConfigured):
             CouchBaseBackend(app=self.app)
 
     def test_init_settings_is_None(self):
-        """Test init settings is None."""
         self.app.conf.couchbase_backend_settings = None
         CouchBaseBackend(app=self.app)
 
     def test_get_connection_connection_exists(self):
-        """Test _get_connection works."""
         with patch('couchbase.connection.Connection') as mock_Connection:
             self.backend._connection = sentinel._connection
 
@@ -67,14 +53,6 @@ class test_CouchBaseBackend(AppCase):
             self.assertFalse(mock_Connection.called)
 
     def test_get(self):
-        """
-        Test get method.
-
-        CouchBaseBackend.get should return  and take two params
-        db conn to couchbase is mocked.
-
-        TODO Should test on key not exists
-        """
         self.app.conf.couchbase_backend_settings = {}
         x = CouchBaseBackend(app=self.app)
         x._connection = Mock()
@@ -85,12 +63,6 @@ class test_CouchBaseBackend(AppCase):
         x._connection.get.assert_called_once_with('1f3fab')
 
     def test_set(self):
-        """
-        Test set method.
-
-        CouchBaseBackend.set should return None and take two params
-        db conn to couchbase is mocked.
-        """
         self.app.conf.couchbase_backend_settings = None
         x = CouchBaseBackend(app=self.app)
         x._connection = MagicMock()
@@ -99,14 +71,6 @@ class test_CouchBaseBackend(AppCase):
         self.assertIsNone(x.set(sentinel.key, sentinel.value))
 
     def test_delete(self):
-        """
-        Test delete method.
-
-        CouchBaseBackend.delete should return and take two params
-        db conn to couchbase is mocked.
-
-        TODO Should test on key not exists.
-        """
         self.app.conf.couchbase_backend_settings = {}
         x = CouchBaseBackend(app=self.app)
         x._connection = Mock()
@@ -117,11 +81,6 @@ class test_CouchBaseBackend(AppCase):
         x._connection.delete.assert_called_once_with('1f3fab')
 
     def test_config_params(self):
-        """
-        Test config params are correct.
-
-        app.conf.couchbase_backend_settings is properly set.
-        """
         self.app.conf.couchbase_backend_settings = {
             'bucket': 'mycoolbucket',
             'host': ['here.host.com', 'there.host.com'],
@@ -137,14 +96,12 @@ class test_CouchBaseBackend(AppCase):
         self.assertEqual(x.port, 1234)
 
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
-        """Test that a CouchBaseBackend is loaded from the couchbase url."""
         from celery.backends.couchbase import CouchBaseBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         self.assertIs(backend, CouchBaseBackend)
         self.assertEqual(url_, url)
 
     def test_backend_params_by_url(self):
-        """Test config params are correct from config url."""
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
         with self.Celery(backend=url) as app:
             x = app.backend
@@ -155,13 +112,6 @@ class test_CouchBaseBackend(AppCase):
             self.assertEqual(x.port, 123)
 
     def test_correct_key_types(self):
-        """
-        Test that the key is the correct type for the couchbase python API.
-
-        We check that get_key_for_task, get_key_for_chord, and
-        get_key_for_group always returns a python string. Need to use str_t
-        for cross Python reasons.
-        """
         keys = [
             self.backend.get_key_for_task('task_id', bytes('key')),
             self.backend.get_key_for_chord('group_id', bytes('key')),

+ 2 - 3
celery/tests/backends/test_couchdb.py

@@ -5,7 +5,7 @@ from celery.backends.couchdb import CouchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery.tests.case import (
-    AppCase, Mock, SkipTest, patch, sentinel,
+    AppCase, Mock, patch, sentinel, skip_unless_module,
 )
 
 try:
@@ -16,11 +16,10 @@ except ImportError:
 COUCHDB_CONTAINER = 'celery_container'
 
 
+@skip_unless_module('pycouchdb')
 class test_CouchBackend(AppCase):
 
     def setup(self):
-        if pycouchdb is None:
-            raise SkipTest('pycouchdb is not installed.')
         self.backend = CouchBackend(app=self.app)
 
     def test_init_no_pycouchdb(self):

+ 6 - 13
celery/tests/backends/test_database.py

@@ -11,11 +11,11 @@ from celery.utils import uuid
 from celery.tests.case import (
     AppCase,
     Mock,
-    SkipTest,
     depends_on_current_app,
     patch,
     skip_if_pypy,
     skip_if_jython,
+    skip_unless_module,
 )
 
 try:
@@ -38,12 +38,9 @@ class SomeClass(object):
         self.data = data
 
 
+@skip_unless_module('sqlalchemy')
 class test_session_cleanup(AppCase):
 
-    def setup(self):
-        if session_cleanup is None:
-            raise SkipTest('slqlalchemy not installed')
-
     def test_context(self):
         session = Mock(name='session')
         with session_cleanup(session):
@@ -59,13 +56,12 @@ class test_session_cleanup(AppCase):
         session.close.assert_called_with()
 
 
+@skip_unless_module('sqlalchemy')
+@skip_if_pypy()
+@skip_if_jython()
 class test_DatabaseBackend(AppCase):
 
-    @skip_if_pypy
-    @skip_if_jython
     def setup(self):
-        if DatabaseBackend is None:
-            raise SkipTest('sqlalchemy not installed')
         self.uri = 'sqlite:///test.db'
         self.app.conf.result_serializer = 'pickle'
 
@@ -218,12 +214,9 @@ class test_DatabaseBackend(AppCase):
         self.assertIn('foo', repr(TaskSet('foo', None)))
 
 
+@skip_unless_module('sqlalchemy')
 class test_SessionManager(AppCase):
 
-    def setup(self):
-        if SessionManager is None:
-            raise SkipTest('sqlalchemy not installed')
-
     def test_after_fork(self):
         s = SessionManager()
         self.assertFalse(s.forked)

+ 2 - 8
celery/tests/backends/test_elasticsearch.py

@@ -5,19 +5,13 @@ from celery.backends import elasticsearch as module
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.exceptions import ImproperlyConfigured
 
-from celery.tests.case import AppCase, Mock, SkipTest, sentinel
-
-try:
-    import elasticsearch
-except ImportError:
-    elasticsearch = None
+from celery.tests.case import AppCase, Mock, sentinel, skip_unless_module
 
 
+@skip_unless_module('elasticsearch')
 class test_ElasticsearchBackend(AppCase):
 
     def setup(self):
-        if elasticsearch is None:
-            raise SkipTest('elasticsearch is not installed.')
         self.backend = ElasticsearchBackend(app=self.app)
 
     def test_init_no_elasticsearch(self):

+ 2 - 4
celery/tests/backends/test_filesystem.py

@@ -3,7 +3,6 @@ from __future__ import absolute_import, unicode_literals
 
 import os
 import shutil
-import sys
 import tempfile
 
 from celery import states
@@ -11,14 +10,13 @@ from celery.backends.filesystem import FilesystemBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.utils import uuid
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_if_win32
 
 
+@skip_if_win32()
 class test_FilesystemBackend(AppCase):
 
     def setup(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
         self.directory = tempfile.mkdtemp()
         self.url = 'file://' + self.directory
         self.path = self.directory.encode('ascii')

+ 7 - 13
celery/tests/backends/test_mongodb.py

@@ -9,13 +9,12 @@ from kombu.exceptions import EncodeError
 from celery import uuid
 from celery import states
 from celery.backends import mongodb as module
-from celery.backends.mongodb import (
-    InvalidDocument, MongoBackend, pymongo,
-)
+from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, SkipTest, ANY,
-    depends_on_current_app, disable_stdouts, patch, sentinel,
+    AppCase, MagicMock, Mock, ANY,
+    depends_on_current_app, override_stdouts, patch, sentinel,
+    skip_unless_module,
 )
 
 COLLECTION = 'taskmeta_celery'
@@ -29,6 +28,7 @@ MONGODB_COLLECTION = 'collection1'
 MONGODB_GROUP_COLLECTION = 'group_collection1'
 
 
+@skip_unless_module('pymongo')
 class test_MongoBackend(AppCase):
 
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
@@ -43,9 +43,6 @@ class test_MongoBackend(AppCase):
     )
 
     def setup(self):
-        if pymongo is None:
-            raise SkipTest('pymongo is not installed.')
-
         R = self._reset = {}
         R['encode'], MongoBackend.encode = MongoBackend.encode, Mock()
         R['decode'], MongoBackend.decode = MongoBackend.decode, Mock()
@@ -410,7 +407,7 @@ class test_MongoBackend(AppCase):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
 
-    @disable_stdouts
+    @override_stdouts
     def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
             'mongodb://user:password@host0.com:43437,host1.com:43437'
@@ -421,12 +418,9 @@ class test_MongoBackend(AppCase):
         self.assertTrue(worker.startup_info())
 
 
+@skip_unless_module('pymongo')
 class test_MongoBackend_no_mock(AppCase):
 
-    def setup(self):
-        if pymongo is None:
-            raise SkipTest('pymongo is not installed.')
-
     def test_encode_decode(self):
         backend = MongoBackend(app=self.app)
         data = {'foo': 1}

+ 6 - 8
celery/tests/backends/test_redis.py

@@ -13,8 +13,8 @@ from celery.datastructures import AttributeDict
 from celery.exceptions import ChordError, ImproperlyConfigured
 
 from celery.tests.case import (
-    ANY, AppCase, ContextMock, Mock, MockCallbacks, SkipTest,
-    call, depends_on_current_app, patch,
+    ANY, AppCase, ContextMock, Mock, MockCallbacks,
+    call, depends_on_current_app, patch, skip_unless_module,
 )
 
 
@@ -142,13 +142,11 @@ class test_RedisBackend(AppCase):
         self.b = self.Backend(app=self.app)
 
     @depends_on_current_app
+    @skip_unless_module('redis')
     def test_reduce(self):
-        try:
-            from celery.backends.redis import RedisBackend
-            x = RedisBackend(app=self.app)
-            self.assertTrue(loads(dumps(x)))
-        except ImportError:
-            raise SkipTest('redis not installed')
+        from celery.backends.redis import RedisBackend
+        x = RedisBackend(app=self.app)
+        self.assertTrue(loads(dumps(x)))
 
     def test_no_redis(self):
         self.Backend.redis = None

+ 3 - 4
celery/tests/backends/test_riak.py

@@ -3,21 +3,20 @@
 from __future__ import absolute_import, unicode_literals
 
 from celery.backends import riak as module
-from celery.backends.riak import RiakBackend, riak
+from celery.backends.riak import RiakBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.case import (
-    AppCase, MagicMock, Mock, SkipTest, patch, sentinel,
+    AppCase, MagicMock, Mock, patch, sentinel, skip_unless_module,
 )
 
 
 RIAK_BUCKET = 'riak_bucket'
 
 
+@skip_unless_module('riak')
 class test_RiakBackend(AppCase):
 
     def setup(self):
-        if riak is None:
-            raise SkipTest('riak is not installed.')
         self.app.conf.result_backend = 'riak://'
 
     @property

+ 2 - 1
celery/tests/bin/test_amqp.py

@@ -7,8 +7,9 @@ from celery.bin.amqp import (
     amqp,
     main,
 )
+from celery.five import WhateverIO
 
-from celery.tests.case import AppCase, Mock, WhateverIO, patch
+from celery.tests.case import AppCase, Mock, patch
 
 
 class test_AMQShell(AppCase):

+ 3 - 2
celery/tests/bin/test_celery.py

@@ -7,7 +7,6 @@ from datetime import datetime
 from kombu.utils.json import dumps
 
 from celery import __main__
-from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 from celery.bin.base import Error
 from celery.bin import celery as mod
 from celery.bin.celery import (
@@ -29,8 +28,10 @@ from celery.bin.celery import (
     _RemoteControl,
     command,
 )
+from celery.five import WhateverIO
+from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 
-from celery.tests.case import AppCase, Mock, WhateverIO, patch
+from celery.tests.case import AppCase, Mock, patch
 
 
 class test__main__(AppCase):

+ 2 - 1
celery/tests/bin/test_celeryevdump.py

@@ -7,8 +7,9 @@ from celery.events.dumper import (
     Dumper,
     evdump,
 )
+from celery.five import WhateverIO
 
-from celery.tests.case import AppCase, Mock, WhateverIO, patch
+from celery.tests.case import AppCase, Mock, patch
 
 
 class test_Dumper(AppCase):

+ 2 - 6
celery/tests/bin/test_events.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import, unicode_literals
 
 from celery.bin import events
 
-from celery.tests.case import AppCase, SkipTest, patch, _old_patch
+from celery.tests.case import AppCase, patch, _old_patch, skip_unless_module
 
 
 class MockCommand(object):
@@ -29,12 +29,8 @@ class test_events(AppCase):
         self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
         self.assertIn('celery events:dump', proctitle.last[0])
 
+    @skip_unless_module('curses')
     def test_run_top(self):
-        try:
-            import curses  # noqa
-        except (ImportError, OSError):
-            raise SkipTest('curses monitor requires curses')
-
         @_old_patch('celery.events.cursesmon', 'evtop',
                     lambda **kw: 'me top, you?')
         @_old_patch('celery.bin.events', 'set_process_title', proctitle)

+ 3 - 3
celery/tests/bin/test_multi.py

@@ -15,8 +15,9 @@ from celery.bin.multi import (
     multi_args,
     __doc__ as doc,
 )
+from celery.five import WhateverIO
 
-from celery.tests.case import AppCase, Mock, WhateverIO, SkipTest, patch
+from celery.tests.case import AppCase, Mock, patch, skip_unless_symbol
 
 
 class test_functions(AppCase):
@@ -264,9 +265,8 @@ class test_MultiTool(AppCase):
 
         )
 
+    @skip_unless_symbol('signal.SIGKILL')
     def test_kill(self):
-        if not hasattr(signal, 'SIGKILL'):
-            raise SkipTest('SIGKILL not supported by this platform')
         self.t.getpids = Mock()
         self.t.getpids.return_value = [
             ('a', None, 10),

+ 35 - 61
celery/tests/bin/test_worker.py

@@ -21,18 +21,19 @@ from celery.worker import state
 from celery.tests.case import (
     AppCase,
     Mock,
-    SkipTest,
-    disable_stdouts,
+    override_stdouts,
     patch,
-    skip_if_pypy,
     skip_if_jython,
+    skip_if_pypy,
+    skip_if_win32,
+    skip_unless_module,
+    skip_unless_symbol,
 )
 
 
 class WorkerAppCase(AppCase):
 
-    def tearDown(self):
-        super(WorkerAppCase, self).tearDown()
+    def teardown(self):
         trace.reset_worker_optimizations()
 
 
@@ -46,13 +47,13 @@ class Worker(cd.Worker):
 class test_Worker(WorkerAppCase):
     Worker = Worker
 
-    @disable_stdouts
+    @override_stdouts
     def test_queues_string(self):
         w = self.app.Worker()
         w.setup_queues('foo,bar,baz')
         self.assertIn('foo', self.app.amqp.queues)
 
-    @disable_stdouts
+    @override_stdouts
     def test_cpu_count(self):
         with patch('celery.worker.cpu_count') as cpu_count:
             cpu_count.side_effect = NotImplementedError()
@@ -61,7 +62,7 @@ class test_Worker(WorkerAppCase):
         w = self.app.Worker(concurrency=5)
         self.assertEqual(w.concurrency, 5)
 
-    @disable_stdouts
+    @override_stdouts
     def test_windows_B_option(self):
         self.app.IS_WINDOWS = True
         with self.assertRaises(SystemExit):
@@ -93,7 +94,7 @@ class test_Worker(WorkerAppCase):
                 x.maybe_detach(['--detach'])
             self.assertTrue(detached.called)
 
-    @disable_stdouts
+    @override_stdouts
     def test_invalid_loglevel_gives_error(self):
         x = worker(app=self.app)
         with self.assertRaises(SystemExit):
@@ -117,12 +118,12 @@ class test_Worker(WorkerAppCase):
         worker.loglevel = logging.INFO
         self.assertTrue(worker.extra_info())
 
-    @disable_stdouts
+    @override_stdouts
     def test_loglevel_string(self):
         worker = self.Worker(app=self.app, loglevel='INFO')
         self.assertEqual(worker.loglevel, logging.INFO)
 
-    @disable_stdouts
+    @override_stdouts
     def test_run_worker(self):
         handlers = {}
 
@@ -150,7 +151,7 @@ class test_Worker(WorkerAppCase):
         finally:
             platforms.signals = p
 
-    @disable_stdouts
+    @override_stdouts
     def test_startup_info(self):
         worker = self.Worker(app=self.app)
         worker.on_start()
@@ -188,18 +189,18 @@ class test_Worker(WorkerAppCase):
         finally:
             cd.ARTLINES = prev
 
-    @disable_stdouts
+    @override_stdouts
     def test_run(self):
         self.Worker(app=self.app).on_start()
         self.Worker(app=self.app, purge=True).on_start()
         worker = self.Worker(app=self.app)
         worker.on_start()
 
-    @disable_stdouts
+    @override_stdouts
     def test_purge_messages(self):
         self.Worker(app=self.app).purge_messages()
 
-    @disable_stdouts
+    @override_stdouts
     def test_init_queues(self):
         app = self.app
         c = app.conf
@@ -230,7 +231,7 @@ class test_Worker(WorkerAppCase):
             app.amqp.queues['image'],
         )
 
-    @disable_stdouts
+    @override_stdouts
     def test_autoscale_argument(self):
         worker1 = self.Worker(app=self.app, autoscale='10,3')
         self.assertListEqual(worker1.autoscale, [10, 3])
@@ -246,20 +247,17 @@ class test_Worker(WorkerAppCase):
         self.assertListEqual(worker2.include, ['os', 'sys'])
         self.Worker(app=self.app, include=['os', 'sys'])
 
-    @disable_stdouts
+    @override_stdouts
     def test_unknown_loglevel(self):
         with self.assertRaises(SystemExit):
             worker(app=self.app).run(loglevel='ALIEN')
         worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
-    @disable_stdouts
+    @skip_if_win32
+    @override_stdouts
     @patch('os._exit')
     def test_warns_if_running_as_privileged_user(self, _exit):
-        app = self.app
-        if app.IS_WINDOWS:
-            raise SkipTest('Not applicable on Windows')
-
         with patch('os.getuid') as getuid:
             getuid.return_value = 0
             self.app.conf.accept_content = ['pickle']
@@ -283,13 +281,13 @@ class test_Worker(WorkerAppCase):
                 worker = self.Worker(app=self.app)
                 worker.on_start()
 
-    @disable_stdouts
+    @override_stdouts
     def test_redirect_stdouts(self):
         self.Worker(app=self.app, redirect_stdouts=False)
         with self.assertRaises(AttributeError):
             sys.stdout.logger
 
-    @disable_stdouts
+    @override_stdouts
     def test_on_start_custom_logging(self):
         self.app.log.redirect_stdouts = Mock()
         worker = self.Worker(app=self.app, redirect_stoutds=True)
@@ -308,7 +306,7 @@ class test_Worker(WorkerAppCase):
         finally:
             self.app.log.setup = prev
 
-    @disable_stdouts
+    @override_stdouts
     def test_startup_info_pool_is_str(self):
         worker = self.Worker(app=self.app, redirect_stdouts=False)
         worker.pool_cls = 'foo'
@@ -331,7 +329,7 @@ class test_Worker(WorkerAppCase):
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
 
-    @disable_stdouts
+    @override_stdouts
     def test_platform_tweaks_osx(self):
 
         class OSXWorker(Worker):
@@ -359,7 +357,7 @@ class test_Worker(WorkerAppCase):
         finally:
             cd.install_HUP_not_supported_handler = prev
 
-    @disable_stdouts
+    @override_stdouts
     def test_general_platform_tweaks(self):
 
         restart_worker_handler_installed = [False]
@@ -380,7 +378,7 @@ class test_Worker(WorkerAppCase):
         finally:
             cd.install_worker_restart_handler = prev
 
-    @disable_stdouts
+    @override_stdouts
     def test_on_consumer_ready(self):
         worker_ready_sent = [False]
 
@@ -392,17 +390,14 @@ class test_Worker(WorkerAppCase):
         self.assertTrue(worker_ready_sent[0])
 
 
+@override_stdouts
 class test_funs(WorkerAppCase):
 
     def test_active_thread_count(self):
         self.assertTrue(cd.active_thread_count())
 
-    @disable_stdouts
+    @skip_unless_module('setproctitle')
     def test_set_process_status(self):
-        try:
-            __import__('setproctitle')
-        except ImportError:
-            raise SkipTest('setproctitle not installed')
         worker = Worker(app=self.app, hostname='xyzza')
         prev1, sys.argv = sys.argv, ['Arg0']
         try:
@@ -422,7 +417,6 @@ class test_funs(WorkerAppCase):
         finally:
             sys.argv = prev1
 
-    @disable_stdouts
     def test_parse_options(self):
         cmd = worker()
         cmd.app = self.app
@@ -431,7 +425,6 @@ class test_funs(WorkerAppCase):
         self.assertEqual(opts.concurrency, 512)
         self.assertEqual(opts.heartbeat_interval, 10)
 
-    @disable_stdouts
     def test_main(self):
         p, cd.Worker = cd.Worker, Worker
         s, sys.argv = sys.argv, ['worker', '--discard']
@@ -442,6 +435,7 @@ class test_funs(WorkerAppCase):
             sys.argv = s
 
 
+@override_stdouts
 class test_signal_handlers(WorkerAppCase):
 
     class _Worker(object):
@@ -468,7 +462,6 @@ class test_signal_handlers(WorkerAppCase):
         finally:
             platforms.signals = p
 
-    @disable_stdouts
     def test_worker_int_handler(self):
         worker = self._Worker()
         handlers = self.psig(cd.install_worker_int_handler, worker)
@@ -511,12 +504,8 @@ class test_signal_handlers(WorkerAppCase):
             with self.assertRaises(WorkerTerminate):
                 next_handlers['SIGINT']('SIGINT', object())
 
-    @disable_stdouts
+    @skip_unless_module('multiprocessing')
     def test_worker_int_handler_only_stop_MainProcess(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('only relevant for multiprocessing')
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         with patch('celery.apps.worker.active_thread_count') as c:
@@ -541,18 +530,13 @@ class test_signal_handlers(WorkerAppCase):
                 process.name = name
                 state.should_stop = None
 
-    @disable_stdouts
     def test_install_HUP_not_supported_handler(self):
         worker = self._Worker()
         handlers = self.psig(cd.install_HUP_not_supported_handler, worker)
         handlers['SIGHUP']('SIGHUP', object())
 
-    @disable_stdouts
+    @skip_unless_module('multiprocessing')
     def test_worker_term_hard_handler_only_stop_MainProcess(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('only relevant for multiprocessing')
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         try:
@@ -579,7 +563,6 @@ class test_signal_handlers(WorkerAppCase):
         finally:
             process.name = name
 
-    @disable_stdouts
     def test_worker_term_handler_when_threads(self):
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 3
@@ -591,7 +574,6 @@ class test_signal_handlers(WorkerAppCase):
             finally:
                 state.should_stop = None
 
-    @disable_stdouts
     def test_worker_term_handler_when_single_thread(self):
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 1
@@ -604,19 +586,15 @@ class test_signal_handlers(WorkerAppCase):
                 state.should_stop = None
 
     @patch('sys.__stderr__')
-    @skip_if_pypy
-    @skip_if_jython
+    @skip_if_pypy()
+    @skip_if_jython()
     def test_worker_cry_handler(self, stderr):
         handlers = self.psig(cd.install_cry_handler)
         self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
         self.assertTrue(stderr.write.called)
 
-    @disable_stdouts
+    @skip_unless_module('multiprocessing')
     def test_worker_term_handler_only_stop_MainProcess(self):
-        try:
-            import _multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('only relevant for multiprocessing')
         process = current_process()
         name, process.name = process.name, 'OtherProcess'
         try:
@@ -636,13 +614,11 @@ class test_signal_handlers(WorkerAppCase):
             process.name = name
             state.should_stop = None
 
-    @disable_stdouts
+    @skip_unless_symbol('os.execv')
     @patch('celery.platforms.close_open_fds')
     @patch('atexit.register')
     @patch('os.close')
     def test_worker_restart_handler(self, _close, register, close_open):
-        if getattr(os, 'execv', None) is None:
-            raise SkipTest('platform does not have excv')
         argv = []
 
         def _execv(*args):
@@ -662,7 +638,6 @@ class test_signal_handlers(WorkerAppCase):
             os.execv = execv
             state.should_stop = None
 
-    @disable_stdouts
     def test_worker_term_hard_handler_when_threaded(self):
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 3
@@ -674,7 +649,6 @@ class test_signal_handlers(WorkerAppCase):
             finally:
                 state.should_terminate = None
 
-    @disable_stdouts
     def test_worker_term_hard_handler_when_single_threaded(self):
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 1

+ 33 - 682
celery/tests/case.py

@@ -1,37 +1,18 @@
 from __future__ import absolute_import, unicode_literals
 
-try:
-    import unittest  # noqa
-    unittest.skip
-    from unittest.util import safe_repr, unorderable_list_difference
-except AttributeError:
-    import unittest2 as unittest  # noqa
-    from unittest2.util import safe_repr, unorderable_list_difference  # noqa
-
 import importlib
 import inspect
 import logging
 import numbers
 import os
-import platform
-import re
 import sys
 import threading
-import time
-import types
-import warnings
 
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import datetime, timedelta
 from functools import partial, wraps
-from types import ModuleType
 
-try:
-    from unittest import mock
-except ImportError:
-    import mock  # noqa
-from nose import SkipTest
 from kombu import Queue
 from kombu.utils import symbol_by_name
 
@@ -39,31 +20,15 @@ from celery import Celery
 from celery.app import current_app
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
-from celery.five import (
-    WhateverIO, builtins, items, reraise,
-    string_t, values, open_fqdn, module_name_t,
-)
-from celery.utils.functional import noop
 from celery.utils.imports import qualname
 
-__all__ = [
-    'Case', 'AppCase', 'Mock', 'MagicMock', 'ANY', 'TaskMessage',
-    'patch', 'call', 'sentinel', 'skip_unless_module',
-    'wrap_logger', 'with_environ', 'sleepdeprived',
-    'skip_if_environ', 'todo', 'skip', 'skip_if',
-    'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
-    'replace_module_value', 'sys_platform', 'reset_modules',
-    'patch_modules', 'mock_context', 'mock_open',
-    'assert_signal_called', 'skip_if_pypy',
-    'skip_if_jython', 'task_message_from_sig', 'restore_logging',
-]
-patch = mock.patch
-call = mock.call
-sentinel = mock.sentinel
-MagicMock = mock.MagicMock
-ANY = mock.ANY
+from ._case import *  # noqa
+from ._case import __all__ as _case_all, Case as _Case, decorator
 
-PY3 = sys.version_info[0] == 3
+__all__ = _case_all + [
+    'AppCase', 'TaskMessage', 'TaskMessage1',
+    'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
+]
 
 CASE_REDEFINES_SETUP = """\
 {name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
@@ -110,6 +75,11 @@ CELERY_TEST_CONFIG = {
 }
 
 
+class Case(_Case):
+    DeprecationWarning = CDeprecationWarning
+    PendingDeprecationWarning = CPendingDeprecationWarning
+
+
 class Trap(object):
 
     def __getattr__(self, name):
@@ -133,299 +103,10 @@ def UnitApp(name=None, set_as_current=False, log=UnitLogging,
     return app
 
 
-class Mock(mock.Mock):
-
-    def __init__(self, *args, **kwargs):
-        attrs = kwargs.pop('attrs', None) or {}
-        super(Mock, self).__init__(*args, **kwargs)
-        for attr_name, attr_value in items(attrs):
-            setattr(self, attr_name, attr_value)
-
-
-class _ContextMock(Mock):
-    """Dummy class implementing __enter__ and __exit__
-    as the :keyword:`with` statement requires these to be implemented
-    in the class, not just the instance."""
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *exc_info):
-        pass
-
-
-def ContextMock(*args, **kwargs):
-    obj = _ContextMock(*args, **kwargs)
-    obj.attach_mock(_ContextMock(), '__enter__')
-    obj.attach_mock(_ContextMock(), '__exit__')
-    obj.__enter__.return_value = obj
-    # if __exit__ return a value the exception is ignored,
-    # so it must return None here.
-    obj.__exit__.return_value = None
-    return obj
-
-
-def _bind(f, o):
-    @wraps(f)
-    def bound_meth(*fargs, **fkwargs):
-        return f(o, *fargs, **fkwargs)
-    return bound_meth
-
-
-if PY3:  # pragma: no cover
-    def _get_class_fun(meth):
-        return meth
-else:
-    def _get_class_fun(meth):
-        return meth.__func__
-
-
-class MockCallbacks(object):
-
-    def __new__(cls, *args, **kwargs):
-        r = Mock(name=cls.__name__)
-        _get_class_fun(cls.__init__)(r, *args, **kwargs)
-        for key, value in items(vars(cls)):
-            if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
-                if inspect.ismethod(value) or inspect.isfunction(value):
-                    r.__getattr__(key).side_effect = _bind(value, r)
-                else:
-                    r.__setattr__(key, value)
-        return r
-
-
-def skip_unless_module(module):
-
-    def _inner(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            try:
-                importlib.import_module(module)
-            except ImportError:
-                raise SkipTest('Does not have %s' % (module,))
-
-            return fun(*args, **kwargs)
-
-        return __inner
-    return _inner
-
-
-# -- adds assertWarns from recent unittest2, not in Python 2.7.
-
-class _AssertRaisesBaseContext(object):
-
-    def __init__(self, expected, test_case, callable_obj=None,
-                 expected_regex=None):
-        self.expected = expected
-        self.failureException = test_case.failureException
-        self.obj_name = None
-        if isinstance(expected_regex, string_t):
-            expected_regex = re.compile(expected_regex)
-        self.expected_regex = expected_regex
-
-
-def _is_magic_module(m):
-    # some libraries create custom module types that are lazily
-    # lodaded, e.g. Django installs some modules in sys.modules that
-    # will load _tkinter and other shit when touched.
-
-    # pyflakes refuses to accept 'noqa' for this isinstance.
-    cls, modtype = type(m), types.ModuleType
-    try:
-        variables = vars(cls)
-    except TypeError:
-        return True
-    else:
-        return (cls is not modtype and (
-            '__getattr__' in variables or
-            '__getattribute__' in variables))
-
-
-class _AssertWarnsContext(_AssertRaisesBaseContext):
-    """A context manager used to implement TestCase.assertWarns* methods."""
-
-    def __enter__(self):
-        # The __warningregistry__'s need to be in a pristine state for tests
-        # to work properly.
-        warnings.resetwarnings()
-        for v in list(values(sys.modules)):
-            # do not evaluate Django moved modules and other lazily
-            # initialized modules.
-            if v and not _is_magic_module(v):
-                # use raw __getattribute__ to protect even better from
-                # lazily loaded modules
-                try:
-                    object.__getattribute__(v, '__warningregistry__')
-                except AttributeError:
-                    pass
-                else:
-                    object.__setattr__(v, '__warningregistry__', {})
-        self.warnings_manager = warnings.catch_warnings(record=True)
-        self.warnings = self.warnings_manager.__enter__()
-        warnings.simplefilter('always', self.expected)
-        return self
-
-    def __exit__(self, exc_type, exc_value, tb):
-        self.warnings_manager.__exit__(exc_type, exc_value, tb)
-        if exc_type is not None:
-            # let unexpected exceptions pass through
-            return
-        try:
-            exc_name = self.expected.__name__
-        except AttributeError:
-            exc_name = str(self.expected)
-        first_matching = None
-        for m in self.warnings:
-            w = m.message
-            if not isinstance(w, self.expected):
-                continue
-            if first_matching is None:
-                first_matching = w
-            if (self.expected_regex is not None and
-                    not self.expected_regex.search(str(w))):
-                continue
-            # store warning for later retrieval
-            self.warning = w
-            self.filename = m.filename
-            self.lineno = m.lineno
-            return
-        # Now we simply try to choose a helpful failure message
-        if first_matching is not None:
-            raise self.failureException(
-                '%r does not match %r' % (
-                    self.expected_regex.pattern, str(first_matching)))
-        if self.obj_name:
-            raise self.failureException(
-                '%s not triggered by %s' % (exc_name, self.obj_name))
-        else:
-            raise self.failureException('%s not triggered' % exc_name)
-
-
 def alive_threads():
     return [thread for thread in threading.enumerate() if thread.is_alive()]
 
 
-class Case(unittest.TestCase):
-
-    def patch(self, *path, **options):
-        manager = patch('.'.join(path), **options)
-        patched = manager.start()
-        self.addCleanup(manager.stop)
-        return patched
-
-    def mock_modules(self, *mods):
-        modules = []
-        for mod in mods:
-            mod = mod.split('.')
-            modules.extend(reversed([
-                '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
-            ]))
-        modules = sorted(set(modules))
-        return self.wrap_context(mock_module(*modules))
-
-    def on_nth_call_do(self, mock, side_effect, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.side_effect = side_effect
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def on_nth_call_return(self, mock, retval, n=1):
-
-        def on_call(*args, **kwargs):
-            if mock.call_count >= n:
-                mock.return_value = retval
-            return mock.return_value
-        mock.side_effect = on_call
-        return mock
-
-    def mask_modules(self, *modules):
-        self.wrap_context(mask_modules(*modules))
-
-    def wrap_context(self, context):
-        ret = context.__enter__()
-        self.addCleanup(partial(context.__exit__, None, None, None))
-        return ret
-
-    def mock_environ(self, env_name, env_value):
-        return self.wrap_context(mock_environ(env_name, env_value))
-
-    def assertWarns(self, expected_warning):
-        return _AssertWarnsContext(expected_warning, self, None)
-
-    def assertWarnsRegex(self, expected_warning, expected_regex):
-        return _AssertWarnsContext(expected_warning, self,
-                                   None, expected_regex)
-
-    @contextmanager
-    def assertDeprecated(self):
-        with self.assertWarnsRegex(CDeprecationWarning,
-                                   r'scheduled for removal'):
-            yield
-
-    @contextmanager
-    def assertPendingDeprecation(self):
-        with self.assertWarnsRegex(CPendingDeprecationWarning,
-                                   r'scheduled for deprecation'):
-            yield
-
-    def assertDictContainsSubset(self, expected, actual, msg=None):
-        missing, mismatched = [], []
-
-        for key, value in items(expected):
-            if key not in actual:
-                missing.append(key)
-            elif value != actual[key]:
-                mismatched.append('%s, expected: %s, actual: %s' % (
-                    safe_repr(key), safe_repr(value),
-                    safe_repr(actual[key])))
-
-        if not (missing or mismatched):
-            return
-
-        standard_msg = ''
-        if missing:
-            standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
-
-        if mismatched:
-            if standard_msg:
-                standard_msg += '; '
-            standard_msg += 'Mismatched values: %s' % (
-                ','.join(mismatched))
-
-        self.fail(self._formatMessage(msg, standard_msg))
-
-    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
-        missing = unexpected = None
-        try:
-            expected = sorted(expected_seq)
-            actual = sorted(actual_seq)
-        except TypeError:
-            # Unsortable items (example: set(), complex(), ...)
-            expected = list(expected_seq)
-            actual = list(actual_seq)
-            missing, unexpected = unorderable_list_difference(
-                expected, actual)
-        else:
-            return self.assertSequenceEqual(expected, actual, msg=msg)
-
-        errors = []
-        if missing:
-            errors.append(
-                'Expected, but missing:\n    %s' % (safe_repr(missing),)
-            )
-        if unexpected:
-            errors.append(
-                'Unexpected, but present:\n    %s' % (safe_repr(unexpected),)
-            )
-        if errors:
-            standardMsg = '\n'.join(errors)
-            self.fail(self._formatMessage(msg, standardMsg))
-
-
 def depends_on_current_app(fun):
     if inspect.isclass(fun):
         fun.contained = False
@@ -443,11 +124,13 @@ class AppCase(Case):
 
     def __init__(self, *args, **kwargs):
         super(AppCase, self).__init__(*args, **kwargs)
-        if self.__class__.__dict__.get('setUp'):
+        setUp = self.__class__.__dict__.get('setUp')
+        tearDown = self.__class__.__dict__.get('tearDown')
+        if setUp and not hasattr(setUp, '__wrapped__'):
             raise RuntimeError(
                 CASE_REDEFINES_SETUP.format(name=qualname(self)),
             )
-        if self.__class__.__dict__.get('tearDown'):
+        if tearDown and not hasattr(tearDown, '__wrapped__'):
             raise RuntimeError(
                 CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
             )
@@ -552,6 +235,9 @@ class AppCase(Case):
         if root.handlers != self.__roothandlers:
             raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
 
+    def assert_signal_called(self, signal, **expected):
+        return assert_signal_called(signal, **expected)
+
     def setup(self):
         pass
 
@@ -559,324 +245,7 @@ class AppCase(Case):
         pass
 
 
-def get_handlers(logger):
-    return [
-        h for h in logger.handlers
-        if not isinstance(h, logging.NullHandler)
-    ]
-
-
-@contextmanager
-def wrap_logger(logger, loglevel=logging.ERROR):
-    old_handlers = get_handlers(logger)
-    sio = WhateverIO()
-    siohandler = logging.StreamHandler(sio)
-    logger.handlers = [siohandler]
-
-    try:
-        yield sio
-    finally:
-        logger.handlers = old_handlers
-
-
-@contextmanager
-def mock_environ(env_name, env_value):
-    sentinel = object()
-    prev_val = os.environ.get(env_name, sentinel)
-    os.environ[env_name] = env_value
-    try:
-        yield env_value
-    finally:
-        if prev_val is sentinel:
-            os.environ.pop(env_name, None)
-        else:
-            os.environ[env_name] = prev_val
-
-
-def with_environ(env_name, env_value):
-
-    def _envpatched(fun):
-
-        @wraps(fun)
-        def _patch_environ(*args, **kwargs):
-            with mock_environ(env_name, env_value):
-                return fun(*args, **kwargs)
-        return _patch_environ
-    return _envpatched
-
-
-def sleepdeprived(module=time):
-
-    def _sleepdeprived(fun):
-
-        @wraps(fun)
-        def __sleepdeprived(*args, **kwargs):
-            old_sleep = module.sleep
-            module.sleep = noop
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                module.sleep = old_sleep
-
-        return __sleepdeprived
-
-    return _sleepdeprived
-
-
-def skip_if_environ(env_var_name):
-
-    def _wrap_test(fun):
-
-        @wraps(fun)
-        def _skips_if_environ(*args, **kwargs):
-            if os.environ.get(env_var_name):
-                raise SkipTest('SKIP %s: %s set\n' % (
-                    fun.__name__, env_var_name))
-            return fun(*args, **kwargs)
-
-        return _skips_if_environ
-
-    return _wrap_test
-
-
-def _skip_test(reason, sign):
-
-    def _wrap_test(fun):
-
-        @wraps(fun)
-        def _skipped_test(*args, **kwargs):
-            raise SkipTest('%s: %s' % (sign, reason))
-
-        return _skipped_test
-    return _wrap_test
-
-
-def todo(reason):
-    """TODO test decorator."""
-    return _skip_test(reason, 'TODO')
-
-
-def skip(reason):
-    """Skip test decorator."""
-    return _skip_test(reason, 'SKIP')
-
-
-def skip_if(predicate, reason):
-    """Skip test if predicate is :const:`True`."""
-
-    def _inner(fun):
-        return predicate and skip(reason)(fun) or fun
-
-    return _inner
-
-
-def skip_unless(predicate, reason):
-    """Skip test if predicate is :const:`False`."""
-    return skip_if(not predicate, reason)
-
-
-# Taken from
-# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
-@contextmanager
-def mask_modules(*modnames):
-    """Ban some modules from being importable inside the context
-
-    For example:
-
-        >>> with mask_modules('sys'):
-        ...     try:
-        ...         import sys
-        ...     except ImportError:
-        ...         print('sys not found')
-        sys not found
-
-        >>> import sys  # noqa
-        >>> sys.version
-        (2, 5, 2, 'final', 0)
-
-    """
-
-    realimport = builtins.__import__
-
-    def myimp(name, *args, **kwargs):
-        if name in modnames:
-            raise ImportError('No module named %s' % name)
-        else:
-            return realimport(name, *args, **kwargs)
-
-    builtins.__import__ = myimp
-    try:
-        yield True
-    finally:
-        builtins.__import__ = realimport
-
-
-@contextmanager
-def override_stdouts():
-    """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
-    prev_out, prev_err = sys.stdout, sys.stderr
-    prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
-    mystdout, mystderr = WhateverIO(), WhateverIO()
-    sys.stdout = sys.__stdout__ = mystdout
-    sys.stderr = sys.__stderr__ = mystderr
-
-    try:
-        yield mystdout, mystderr
-    finally:
-        sys.stdout = prev_out
-        sys.stderr = prev_err
-        sys.__stdout__ = prev_rout
-        sys.__stderr__ = prev_rerr
-
-
-def disable_stdouts(fun):
-
-    @wraps(fun)
-    def disable(*args, **kwargs):
-        with override_stdouts():
-            return fun(*args, **kwargs)
-    return disable
-
-
-def _old_patch(module, name, mocked):
-    module = importlib.import_module(module)
-
-    def _patch(fun):
-
-        @wraps(fun)
-        def __patched(*args, **kwargs):
-            prev = getattr(module, name)
-            setattr(module, name, mocked)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                setattr(module, name, prev)
-        return __patched
-    return _patch
-
-
-@contextmanager
-def replace_module_value(module, name, value=None):
-    has_prev = hasattr(module, name)
-    prev = getattr(module, name, None)
-    if value:
-        setattr(module, name, value)
-    else:
-        try:
-            delattr(module, name)
-        except AttributeError:
-            pass
-    try:
-        yield
-    finally:
-        if prev is not None:
-            setattr(module, name, prev)
-        if not has_prev:
-            try:
-                delattr(module, name)
-            except AttributeError:
-                pass
-pypy_version = partial(
-    replace_module_value, sys, 'pypy_version_info',
-)
-platform_pyimp = partial(
-    replace_module_value, platform, 'python_implementation',
-)
-
-
-@contextmanager
-def sys_platform(value):
-    prev, sys.platform = sys.platform, value
-    try:
-        yield
-    finally:
-        sys.platform = prev
-
-
-@contextmanager
-def reset_modules(*modules):
-    prev = {k: sys.modules.pop(k) for k in modules if k in sys.modules}
-    try:
-        yield
-    finally:
-        sys.modules.update(prev)
-
-
-@contextmanager
-def patch_modules(*modules):
-    prev = {}
-    for mod in modules:
-        prev[mod] = sys.modules.get(mod)
-        sys.modules[mod] = ModuleType(module_name_t(mod))
-    try:
-        yield
-    finally:
-        for name, mod in items(prev):
-            if mod is None:
-                sys.modules.pop(name, None)
-            else:
-                sys.modules[name] = mod
-
-
-@contextmanager
-def mock_module(*names):
-    prev = {}
-
-    class MockModule(ModuleType):
-
-        def __getattr__(self, attr):
-            setattr(self, attr, Mock())
-            return ModuleType.__getattribute__(self, attr)
-
-    mods = []
-    for name in names:
-        try:
-            prev[name] = sys.modules[name]
-        except KeyError:
-            pass
-        mod = sys.modules[name] = MockModule(module_name_t(name))
-        mods.append(mod)
-    try:
-        yield mods
-    finally:
-        for name in names:
-            try:
-                sys.modules[name] = prev[name]
-            except KeyError:
-                try:
-                    del(sys.modules[name])
-                except KeyError:
-                    pass
-
-
-@contextmanager
-def mock_context(mock, typ=Mock):
-    context = mock.return_value = Mock()
-    context.__enter__ = typ()
-    context.__exit__ = typ()
-
-    def on_exit(*x):
-        if x[0]:
-            reraise(x[0], x[1], x[2])
-    context.__exit__.side_effect = on_exit
-    context.__enter__.return_value = context
-    try:
-        yield context
-    finally:
-        context.reset()
-
-
-@contextmanager
-def mock_open(typ=WhateverIO, side_effect=None):
-    with patch(open_fqdn) as open_:
-        with mock_context(open_) as context:
-            if side_effect is not None:
-                context.__enter__.side_effect = side_effect
-            val = context.__enter__.return_value = typ()
-            val.__exit__ = Mock()
-            yield val
-
-
+@decorator
 @contextmanager
 def assert_signal_called(signal, **expected):
     handler = Mock()
@@ -889,26 +258,6 @@ def assert_signal_called(signal, **expected):
     handler.assert_called_with(signal=signal, **expected)
 
 
-def skip_if_pypy(fun):
-
-    @wraps(fun)
-    def _inner(*args, **kwargs):
-        if getattr(sys, 'pypy_version_info', None):
-            raise SkipTest('does not work on PyPy')
-        return fun(*args, **kwargs)
-    return _inner
-
-
-def skip_if_jython(fun):
-
-    @wraps(fun)
-    def _inner(*args, **kwargs):
-        if sys.platform.startswith('java'):
-            raise SkipTest('does not work on Jython')
-        return fun(*args, **kwargs)
-    return _inner
-
-
 def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
                 errbacks=None, chain=None, shadow=None, utc=None, **options):
     from celery import uuid
@@ -979,16 +328,18 @@ def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
     )
 
 
-@contextmanager
-def restore_logging():
-    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
-    root = logging.getLogger()
-    level = root.level
-    handlers = root.handlers
+def _old_patch(module, name, mocked):
+    module = importlib.import_module(module)
 
-    try:
-        yield
-    finally:
-        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
-        root.level = level
-        root.handlers[:] = handlers
+    def _patch(fun):
+
+        @wraps(fun)
+        def __patched(*args, **kwargs):
+            prev = getattr(module, name)
+            setattr(module, name, mocked)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                setattr(module, name, prev)
+        return __patched
+    return _patch

+ 1 - 2
celery/tests/concurrency/test_eventlet.py

@@ -12,13 +12,12 @@ from celery.concurrency.eventlet import (
 from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
 
 
+@skip_if_pypy()
 class EventletCase(AppCase):
 
-    @skip_if_pypy
     def setup(self):
         self.mock_modules(*eventlet_modules)
 
-    @skip_if_pypy
     def teardown(self):
         for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
             try:

+ 1 - 1
celery/tests/concurrency/test_gevent.py

@@ -17,9 +17,9 @@ gevent_modules = (
 )
 
 
+@skip_if_pypy()
 class GeventCase(AppCase):
 
-    @skip_if_pypy
     def setup(self):
         self.mock_modules(*gevent_modules)
 

+ 2 - 5
celery/tests/concurrency/test_pool.py

@@ -5,7 +5,7 @@ import itertools
 
 from billiard.einfo import ExceptionInfo
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_unless_module
 
 
 def do_something(i):
@@ -23,13 +23,10 @@ def raise_something(i):
         return ExceptionInfo()
 
 
+@skip_unless_module('multiprocessing')
 class test_TaskPool(AppCase):
 
     def setup(self):
-        try:
-            __import__('multiprocessing')
-        except ImportError:
-            raise SkipTest('multiprocessing not supported')
         from celery.concurrency.prefork import TaskPool
         self.TaskPool = TaskPool
 

+ 7 - 16
celery/tests/concurrency/test_prefork.py

@@ -3,7 +3,6 @@ from __future__ import absolute_import, unicode_literals
 import errno
 import os
 import socket
-import sys
 
 from itertools import cycle
 
@@ -13,7 +12,9 @@ from celery.five import range
 from celery.utils.functional import noop
 from celery.utils.objects import Bunch
 
-from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
+from celery.tests.case import (
+    AppCase, Mock, patch, restore_logging, skip_if_win32, skip_unless_module,
+)
 
 try:
     from celery.concurrency import prefork as mp
@@ -185,21 +186,14 @@ class ExeMockTaskPool(mp.TaskPool):
     Pool = BlockingPool = ExeMockPool
 
 
+@skip_unless_module('multiprocessing')
 class PoolCase(AppCase):
-
-    def setup(self):
-        try:
-            import multiprocessing  # noqa
-        except ImportError:
-            raise SkipTest('multiprocessing not supported')
+    pass
 
 
+@skip_if_win32
 class test_AsynPool(PoolCase):
 
-    def setup(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
-
     def test_gen_not_started(self):
 
         def gen():
@@ -303,12 +297,9 @@ class test_AsynPool(PoolCase):
         w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
 
 
+@skip_if_win32
 class test_ResultHandler(PoolCase):
 
-    def setup(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
-
     def test_process_result(self):
         x = asynpool.ResultHandler(
             Mock(), Mock(), {}, Mock(),

+ 4 - 3
celery/tests/contrib/test_rdb.py

@@ -8,7 +8,8 @@ from celery.contrib.rdb import (
     debugger,
     set_trace,
 )
-from celery.tests.case import AppCase, Mock, WhateverIO, patch, skip_if_pypy
+from celery.five import WhateverIO
+from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
 
 
 class SockErr(socket.error):
@@ -31,7 +32,7 @@ class test_Rdb(AppCase):
         self.assertTrue(debugger.return_value.set_trace.called)
 
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
-    @skip_if_pypy
+    @skip_if_pypy()
     def test_rdb(self, get_avail_port):
         sock = Mock()
         get_avail_port.return_value = (sock, 8000)
@@ -75,7 +76,7 @@ class test_Rdb(AppCase):
             rdb.set_quit.assert_called_with()
 
     @patch('socket.socket')
-    @skip_if_pypy
+    @skip_if_pypy()
     def test_get_avail_port(self, sock):
         out = WhateverIO()
         sock.return_value.accept.return_value = (Mock(), ['helu'])

+ 2 - 6
celery/tests/events/test_cursesmon.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_unless_module
 
 
 class MockWindow(object):
@@ -9,14 +9,10 @@ class MockWindow(object):
         return self.y, self.x
 
 
+@skip_unless_module('curses')
 class test_CursesDisplay(AppCase):
 
     def setup(self):
-        try:
-            import curses  # noqa
-        except (ImportError, OSError):
-            raise SkipTest('curses monitor requires curses')
-
         from celery.events import cursesmon
         self.monitor = cursesmon.CursesMonitor(object(), app=self.app)
         self.win = MockWindow()

+ 2 - 2
celery/tests/events/test_state.py

@@ -19,7 +19,7 @@ from celery.events.state import (
 )
 from celery.five import range
 from celery.utils import uuid
-from celery.tests.case import AppCase, Mock, SkipTest, patch
+from celery.tests.case import AppCase, Mock, patch, todo
 
 try:
     Decimal(2.6)
@@ -374,8 +374,8 @@ class test_State(AppCase):
         self.assertEqual(now[1][0], tC)
         self.assertEqual(now[2][0], tB)
 
+    @todo(reason='not working')
     def test_task_descending_clock_ordering(self):
-        raise SkipTest('not working')
         state = State()
         r = ev_logical_clock_ordering(state)
         tA, tB, tC = r.uids

+ 3 - 7
celery/tests/security/case.py

@@ -1,12 +1,8 @@
 from __future__ import absolute_import, unicode_literals
 
-from celery.tests.case import AppCase, SkipTest
+from celery.tests.case import AppCase, skip_unless_module
 
 
+@skip_unless_module('OpenSSL.crypto', name='pyOpenSSL')
 class SecurityCase(AppCase):
-
-    def setup(self):
-        try:
-            from OpenSSL import crypto  # noqa
-        except ImportError:
-            raise SkipTest('OpenSSL.crypto not installed')
+    pass

+ 2 - 2
celery/tests/security/test_certificate.py

@@ -6,7 +6,7 @@ from celery.security.certificate import Certificate, CertStore, FSCertStore
 from . import CERT1, CERT2, KEY1
 from .case import SecurityCase
 
-from celery.tests.case import Mock, SkipTest, mock_open, patch
+from celery.tests.case import Mock, mock_open, patch, todo
 
 
 class test_Certificate(SecurityCase):
@@ -27,8 +27,8 @@ class test_Certificate(SecurityCase):
         with self.assertRaises(SecurityError):
             Certificate(KEY1)
 
+    @todo(reason='cert expired')
     def test_has_expired(self):
-        raise SkipTest('cert expired')
         self.assertFalse(Certificate(CERT1).has_expired())
 
     def test_has_expired_mock(self):

+ 3 - 11
celery/tests/utils/test_datastructures.py

@@ -1,7 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 
 import pickle
-import sys
 
 from collections import Mapping
 from itertools import count
@@ -16,10 +15,10 @@ from celery.datastructures import (
     ConfigurationView,
     DependencyGraph,
 )
-from celery.five import items
+from celery.five import WhateverIO, items
 from celery.utils.objects import Bunch
 
-from celery.tests.case import Case, Mock, WhateverIO, SkipTest
+from celery.tests.case import Case, Mock, skip_if_win32
 
 
 class test_DictAttribute(Case):
@@ -168,15 +167,10 @@ class test_ExceptionInfo(Case):
             self.assertTrue(r)
 
 
+@skip_if_win32()
 class test_LimitedSet(Case):
 
-    def setUp(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Not working on Windows')
-
     def test_add(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Not working properly on Windows')
         s = LimitedSet(maxlen=2)
         s.add('foo')
         s.add('bar')
@@ -239,8 +233,6 @@ class test_LimitedSet(Case):
         self.assertEqual(pickle.loads(pickle.dumps(s)), s)
 
     def test_iter(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Not working on Windows')
         s = LimitedSet(maxlen=3)
         items = ['foo', 'bar', 'baz', 'xaz']
         for item in items:

+ 571 - 561
celery/tests/utils/test_platforms.py

@@ -8,7 +8,7 @@ import tempfile
 
 from celery import _find_option_with_arg
 from celery import platforms
-from celery.five import open_fqdn
+from celery.five import WhateverIO
 from celery.platforms import (
     get_fdmax,
     ignore_errno,
@@ -38,9 +38,10 @@ try:
 except ImportError:  # pragma: no cover
     resource = None  # noqa
 
+from celery.tests._case import open_fqdn
 from celery.tests.case import (
-    Case, WhateverIO, Mock, SkipTest,
-    call, override_stdouts, mock_open, patch,
+    Case, Mock,
+    call, override_stdouts, mock_open, patch, skip_if_win32,
 )
 
 
@@ -59,12 +60,9 @@ class test_find_option_with_arg(Case):
         )
 
 
+@skip_if_win32()
 class test_fd_by_path(Case):
 
-    def setUp(self):
-        if sys.platform == 'win32':
-            raise SkipTest('win32: skip')
-
     def test_finds(self):
         test_file = tempfile.NamedTemporaryFile()
         try:
@@ -143,9 +141,8 @@ class test_Signals(Case):
         self.assertTrue(signals.supported('INT'))
         self.assertFalse(signals.supported('SIGIMAGINARY'))
 
+    @skip_if_win32()
     def test_reset_alarm(self):
-        if sys.platform == 'win32':
-            raise SkipTest('signal.alarm not available on Windows')
         with patch('signal.alarm') as _alarm:
             signals.reset_alarm()
             _alarm.assert_called_with(0)
@@ -189,622 +186,635 @@ class test_Signals(Case):
         signals['INT'] = lambda *a: a
 
 
-if not platforms.IS_WINDOWS:
-
-    class test_get_fdmax(Case):
-
-        @patch('resource.getrlimit')
-        def test_when_infinity(self, getrlimit):
-            with patch('os.sysconf') as sysconfig:
-                sysconfig.side_effect = KeyError()
-                getrlimit.return_value = [None, resource.RLIM_INFINITY]
-                default = object()
-                self.assertIs(get_fdmax(default), default)
-
-        @patch('resource.getrlimit')
-        def test_when_actual(self, getrlimit):
-            with patch('os.sysconf') as sysconfig:
-                sysconfig.side_effect = KeyError()
-                getrlimit.return_value = [None, 13]
-                self.assertEqual(get_fdmax(None), 13)
-
-    class test_maybe_drop_privileges(Case):
-
-        def test_on_windows(self):
-            prev, sys.platform = sys.platform, 'win32'
-            try:
-                maybe_drop_privileges()
-            finally:
-                sys.platform = prev
-
-        @patch('os.getegid')
-        @patch('os.getgid')
-        @patch('os.geteuid')
-        @patch('os.getuid')
-        @patch('celery.platforms.parse_uid')
-        @patch('celery.platforms.parse_gid')
-        @patch('pwd.getpwuid')
-        @patch('celery.platforms.setgid')
-        @patch('celery.platforms.setuid')
-        @patch('celery.platforms.initgroups')
-        def test_with_uid(self, initgroups, setuid, setgid,
-                          getpwuid, parse_gid, parse_uid, getuid, geteuid,
-                          getgid, getegid):
-            geteuid.return_value = 10
-            getuid.return_value = 10
-
-            class pw_struct(object):
-                pw_gid = 50001
-
-            def raise_on_second_call(*args, **kwargs):
-                setuid.side_effect = OSError()
-                setuid.side_effect.errno = errno.EPERM
-            setuid.side_effect = raise_on_second_call
-            getpwuid.return_value = pw_struct()
-            parse_uid.return_value = 5001
-            parse_gid.return_value = 5001
-            maybe_drop_privileges(uid='user')
-            parse_uid.assert_called_with('user')
-            getpwuid.assert_called_with(5001)
-            setgid.assert_called_with(50001)
-            initgroups.assert_called_with(5001, 50001)
-            setuid.assert_has_calls([call(5001), call(0)])
-
-            setuid.side_effect = raise_on_second_call
-
-            def to_root_on_second_call(mock, first):
-                return_value = [first]
-
-                def on_first_call(*args, **kwargs):
-                    ret, return_value[0] = return_value[0], 0
-                    return ret
-                mock.side_effect = on_first_call
-            to_root_on_second_call(geteuid, 10)
-            to_root_on_second_call(getuid, 10)
-            with self.assertRaises(AssertionError):
-                maybe_drop_privileges(uid='user')
+@skip_if_win32()
+class test_get_fdmax(Case):
 
-            getuid.return_value = getuid.side_effect = None
-            geteuid.return_value = geteuid.side_effect = None
-            getegid.return_value = 0
-            getgid.return_value = 0
-            setuid.side_effect = raise_on_second_call
-            with self.assertRaises(AssertionError):
-                maybe_drop_privileges(gid='group')
-
-            getuid.reset_mock()
-            geteuid.reset_mock()
-            setuid.reset_mock()
-            getuid.side_effect = geteuid.side_effect = None
-
-            def raise_on_second_call(*args, **kwargs):
-                setuid.side_effect = OSError()
-                setuid.side_effect.errno = errno.ENOENT
-            setuid.side_effect = raise_on_second_call
-            with self.assertRaises(OSError):
-                maybe_drop_privileges(uid='user')
-
-        @patch('celery.platforms.parse_uid')
-        @patch('celery.platforms.parse_gid')
-        @patch('celery.platforms.setgid')
-        @patch('celery.platforms.setuid')
-        @patch('celery.platforms.initgroups')
-        def test_with_guid(self, initgroups, setuid, setgid,
-                           parse_gid, parse_uid):
-
-            def raise_on_second_call(*args, **kwargs):
-                setuid.side_effect = OSError()
-                setuid.side_effect.errno = errno.EPERM
-            setuid.side_effect = raise_on_second_call
-            parse_uid.return_value = 5001
-            parse_gid.return_value = 50001
-            maybe_drop_privileges(uid='user', gid='group')
-            parse_uid.assert_called_with('user')
-            parse_gid.assert_called_with('group')
-            setgid.assert_called_with(50001)
-            initgroups.assert_called_with(5001, 50001)
-            setuid.assert_has_calls([call(5001), call(0)])
+    @patch('resource.getrlimit')
+    def test_when_infinity(self, getrlimit):
+        with patch('os.sysconf') as sysconfig:
+            sysconfig.side_effect = KeyError()
+            getrlimit.return_value = [None, resource.RLIM_INFINITY]
+            default = object()
+            self.assertIs(get_fdmax(default), default)
 
-            setuid.side_effect = None
-            with self.assertRaises(RuntimeError):
-                maybe_drop_privileges(uid='user', gid='group')
+    @patch('resource.getrlimit')
+    def test_when_actual(self, getrlimit):
+        with patch('os.sysconf') as sysconfig:
+            sysconfig.side_effect = KeyError()
+            getrlimit.return_value = [None, 13]
+            self.assertEqual(get_fdmax(None), 13)
+
+
+@skip_if_win32()
+class test_maybe_drop_privileges(Case):
+
+    def test_on_windows(self):
+        prev, sys.platform = sys.platform, 'win32'
+        try:
+            maybe_drop_privileges()
+        finally:
+            sys.platform = prev
+
+    @patch('os.getegid')
+    @patch('os.getgid')
+    @patch('os.geteuid')
+    @patch('os.getuid')
+    @patch('celery.platforms.parse_uid')
+    @patch('celery.platforms.parse_gid')
+    @patch('pwd.getpwuid')
+    @patch('celery.platforms.setgid')
+    @patch('celery.platforms.setuid')
+    @patch('celery.platforms.initgroups')
+    def test_with_uid(self, initgroups, setuid, setgid,
+                      getpwuid, parse_gid, parse_uid, getuid, geteuid,
+                      getgid, getegid):
+        geteuid.return_value = 10
+        getuid.return_value = 10
+
+        class pw_struct(object):
+            pw_gid = 50001
+
+        def raise_on_second_call(*args, **kwargs):
             setuid.side_effect = OSError()
-            setuid.side_effect.errno = errno.EINVAL
-            with self.assertRaises(OSError):
-                maybe_drop_privileges(uid='user', gid='group')
-
-        @patch('celery.platforms.setuid')
-        @patch('celery.platforms.setgid')
-        @patch('celery.platforms.parse_gid')
-        def test_only_gid(self, parse_gid, setgid, setuid):
-            parse_gid.return_value = 50001
+            setuid.side_effect.errno = errno.EPERM
+        setuid.side_effect = raise_on_second_call
+        getpwuid.return_value = pw_struct()
+        parse_uid.return_value = 5001
+        parse_gid.return_value = 5001
+        maybe_drop_privileges(uid='user')
+        parse_uid.assert_called_with('user')
+        getpwuid.assert_called_with(5001)
+        setgid.assert_called_with(50001)
+        initgroups.assert_called_with(5001, 50001)
+        setuid.assert_has_calls([call(5001), call(0)])
+
+        setuid.side_effect = raise_on_second_call
+
+        def to_root_on_second_call(mock, first):
+            return_value = [first]
+
+            def on_first_call(*args, **kwargs):
+                ret, return_value[0] = return_value[0], 0
+                return ret
+            mock.side_effect = on_first_call
+        to_root_on_second_call(geteuid, 10)
+        to_root_on_second_call(getuid, 10)
+        with self.assertRaises(AssertionError):
+            maybe_drop_privileges(uid='user')
+
+        getuid.return_value = getuid.side_effect = None
+        geteuid.return_value = geteuid.side_effect = None
+        getegid.return_value = 0
+        getgid.return_value = 0
+        setuid.side_effect = raise_on_second_call
+        with self.assertRaises(AssertionError):
             maybe_drop_privileges(gid='group')
-            parse_gid.assert_called_with('group')
-            setgid.assert_called_with(50001)
-            self.assertFalse(setuid.called)
 
-    class test_setget_uid_gid(Case):
+        getuid.reset_mock()
+        geteuid.reset_mock()
+        setuid.reset_mock()
+        getuid.side_effect = geteuid.side_effect = None
 
-        @patch('celery.platforms.parse_uid')
-        @patch('os.setuid')
-        def test_setuid(self, _setuid, parse_uid):
-            parse_uid.return_value = 5001
-            setuid('user')
-            parse_uid.assert_called_with('user')
-            _setuid.assert_called_with(5001)
+        def raise_on_second_call(*args, **kwargs):
+            setuid.side_effect = OSError()
+            setuid.side_effect.errno = errno.ENOENT
+        setuid.side_effect = raise_on_second_call
+        with self.assertRaises(OSError):
+            maybe_drop_privileges(uid='user')
 
-        @patch('celery.platforms.parse_gid')
-        @patch('os.setgid')
-        def test_setgid(self, _setgid, parse_gid):
-            parse_gid.return_value = 50001
-            setgid('group')
-            parse_gid.assert_called_with('group')
-            _setgid.assert_called_with(50001)
+    @patch('celery.platforms.parse_uid')
+    @patch('celery.platforms.parse_gid')
+    @patch('celery.platforms.setgid')
+    @patch('celery.platforms.setuid')
+    @patch('celery.platforms.initgroups')
+    def test_with_guid(self, initgroups, setuid, setgid,
+                       parse_gid, parse_uid):
 
-        def test_parse_uid_when_int(self):
-            self.assertEqual(parse_uid(5001), 5001)
+        def raise_on_second_call(*args, **kwargs):
+            setuid.side_effect = OSError()
+            setuid.side_effect.errno = errno.EPERM
+        setuid.side_effect = raise_on_second_call
+        parse_uid.return_value = 5001
+        parse_gid.return_value = 50001
+        maybe_drop_privileges(uid='user', gid='group')
+        parse_uid.assert_called_with('user')
+        parse_gid.assert_called_with('group')
+        setgid.assert_called_with(50001)
+        initgroups.assert_called_with(5001, 50001)
+        setuid.assert_has_calls([call(5001), call(0)])
+
+        setuid.side_effect = None
+        with self.assertRaises(RuntimeError):
+            maybe_drop_privileges(uid='user', gid='group')
+        setuid.side_effect = OSError()
+        setuid.side_effect.errno = errno.EINVAL
+        with self.assertRaises(OSError):
+            maybe_drop_privileges(uid='user', gid='group')
 
-        @patch('pwd.getpwnam')
-        def test_parse_uid_when_existing_name(self, getpwnam):
+    @patch('celery.platforms.setuid')
+    @patch('celery.platforms.setgid')
+    @patch('celery.platforms.parse_gid')
+    def test_only_gid(self, parse_gid, setgid, setuid):
+        parse_gid.return_value = 50001
+        maybe_drop_privileges(gid='group')
+        parse_gid.assert_called_with('group')
+        setgid.assert_called_with(50001)
+        self.assertFalse(setuid.called)
 
-            class pwent(object):
-                pw_uid = 5001
 
-            getpwnam.return_value = pwent()
-            self.assertEqual(parse_uid('user'), 5001)
+@skip_if_win32()
+class test_setget_uid_gid(Case):
 
-        @patch('pwd.getpwnam')
-        def test_parse_uid_when_nonexisting_name(self, getpwnam):
-            getpwnam.side_effect = KeyError('user')
+    @patch('celery.platforms.parse_uid')
+    @patch('os.setuid')
+    def test_setuid(self, _setuid, parse_uid):
+        parse_uid.return_value = 5001
+        setuid('user')
+        parse_uid.assert_called_with('user')
+        _setuid.assert_called_with(5001)
 
-            with self.assertRaises(KeyError):
-                parse_uid('user')
+    @patch('celery.platforms.parse_gid')
+    @patch('os.setgid')
+    def test_setgid(self, _setgid, parse_gid):
+        parse_gid.return_value = 50001
+        setgid('group')
+        parse_gid.assert_called_with('group')
+        _setgid.assert_called_with(50001)
 
-        def test_parse_gid_when_int(self):
-            self.assertEqual(parse_gid(50001), 50001)
+    def test_parse_uid_when_int(self):
+        self.assertEqual(parse_uid(5001), 5001)
 
-        @patch('grp.getgrnam')
-        def test_parse_gid_when_existing_name(self, getgrnam):
+    @patch('pwd.getpwnam')
+    def test_parse_uid_when_existing_name(self, getpwnam):
 
-            class grent(object):
-                gr_gid = 50001
+        class pwent(object):
+            pw_uid = 5001
+
+        getpwnam.return_value = pwent()
+        self.assertEqual(parse_uid('user'), 5001)
 
-            getgrnam.return_value = grent()
-            self.assertEqual(parse_gid('group'), 50001)
+    @patch('pwd.getpwnam')
+    def test_parse_uid_when_nonexisting_name(self, getpwnam):
+        getpwnam.side_effect = KeyError('user')
 
-        @patch('grp.getgrnam')
-        def test_parse_gid_when_nonexisting_name(self, getgrnam):
-            getgrnam.side_effect = KeyError('group')
+        with self.assertRaises(KeyError):
+            parse_uid('user')
 
-            with self.assertRaises(KeyError):
-                parse_gid('group')
+    def test_parse_gid_when_int(self):
+        self.assertEqual(parse_gid(50001), 50001)
 
-    class test_initgroups(Case):
+    @patch('grp.getgrnam')
+    def test_parse_gid_when_existing_name(self, getgrnam):
 
-        @patch('pwd.getpwuid')
-        @patch('os.initgroups', create=True)
-        def test_with_initgroups(self, initgroups_, getpwuid):
+        class grent(object):
+            gr_gid = 50001
+
+        getgrnam.return_value = grent()
+        self.assertEqual(parse_gid('group'), 50001)
+
+    @patch('grp.getgrnam')
+    def test_parse_gid_when_nonexisting_name(self, getgrnam):
+        getgrnam.side_effect = KeyError('group')
+
+        with self.assertRaises(KeyError):
+            parse_gid('group')
+
+
+@skip_if_win32()
+class test_initgroups(Case):
+
+    @patch('pwd.getpwuid')
+    @patch('os.initgroups', create=True)
+    def test_with_initgroups(self, initgroups_, getpwuid):
+        getpwuid.return_value = ['user']
+        initgroups(5001, 50001)
+        initgroups_.assert_called_with('user', 50001)
+
+    @patch('celery.platforms.setgroups')
+    @patch('grp.getgrall')
+    @patch('pwd.getpwuid')
+    def test_without_initgroups(self, getpwuid, getgrall, setgroups):
+        prev = getattr(os, 'initgroups', None)
+        try:
+            delattr(os, 'initgroups')
+        except AttributeError:
+            pass
+        try:
             getpwuid.return_value = ['user']
+
+            class grent(object):
+                gr_mem = ['user']
+
+                def __init__(self, gid):
+                    self.gr_gid = gid
+
+            getgrall.return_value = [grent(1), grent(2), grent(3)]
             initgroups(5001, 50001)
-            initgroups_.assert_called_with('user', 50001)
-
-        @patch('celery.platforms.setgroups')
-        @patch('grp.getgrall')
-        @patch('pwd.getpwuid')
-        def test_without_initgroups(self, getpwuid, getgrall, setgroups):
-            prev = getattr(os, 'initgroups', None)
-            try:
-                delattr(os, 'initgroups')
-            except AttributeError:
-                pass
-            try:
-                getpwuid.return_value = ['user']
-
-                class grent(object):
-                    gr_mem = ['user']
-
-                    def __init__(self, gid):
-                        self.gr_gid = gid
-
-                getgrall.return_value = [grent(1), grent(2), grent(3)]
-                initgroups(5001, 50001)
-                setgroups.assert_called_with([1, 2, 3])
-            finally:
-                if prev:
-                    os.initgroups = prev
-
-    class test_detached(Case):
-
-        def test_without_resource(self):
-            prev, platforms.resource = platforms.resource, None
-            try:
-                with self.assertRaises(RuntimeError):
-                    detached()
-            finally:
-                platforms.resource = prev
-
-        @patch('celery.platforms._create_pidlock')
-        @patch('celery.platforms.signals')
-        @patch('celery.platforms.maybe_drop_privileges')
-        @patch('os.geteuid')
-        @patch(open_fqdn)
-        def test_default(self, open, geteuid, maybe_drop,
-                         signals, pidlock):
-            geteuid.return_value = 0
-            context = detached(uid='user', gid='group')
-            self.assertIsInstance(context, DaemonContext)
-            signals.reset.assert_called_with('SIGCLD')
-            maybe_drop.assert_called_with(uid='user', gid='group')
-            open.return_value = Mock()
-
-            geteuid.return_value = 5001
-            context = detached(uid='user', gid='group', logfile='/foo/bar')
-            self.assertIsInstance(context, DaemonContext)
-            self.assertTrue(context.after_chdir)
-            context.after_chdir()
-            open.assert_called_with('/foo/bar', 'a')
-            open.return_value.close.assert_called_with()
-
-            context = detached(pidfile='/foo/bar/pid')
-            self.assertIsInstance(context, DaemonContext)
-            self.assertTrue(context.after_chdir)
-            context.after_chdir()
-            pidlock.assert_called_with('/foo/bar/pid')
-
-    class test_DaemonContext(Case):
-
-        @patch('os.fork')
-        @patch('os.setsid')
-        @patch('os._exit')
-        @patch('os.chdir')
-        @patch('os.umask')
-        @patch('os.close')
-        @patch('os.closerange')
-        @patch('os.open')
-        @patch('os.dup2')
-        def test_open(self, dup2, open, close, closer, umask, chdir,
-                      _exit, setsid, fork):
-            x = DaemonContext(workdir='/opt/workdir', umask=0o22)
-            x.stdfds = [0, 1, 2]
-
-            fork.return_value = 0
-            with x:
-                self.assertTrue(x._is_open)
-                with x:
-                    pass
-            self.assertEqual(fork.call_count, 2)
-            setsid.assert_called_with()
-            self.assertFalse(_exit.called)
-
-            chdir.assert_called_with(x.workdir)
-            umask.assert_called_with(0o22)
-            self.assertTrue(dup2.called)
-
-            fork.reset_mock()
-            fork.return_value = 1
-            x = DaemonContext(workdir='/opt/workdir')
-            x.stdfds = [0, 1, 2]
-            with x:
-                pass
-            self.assertEqual(fork.call_count, 1)
-            _exit.assert_called_with(0)
+            setgroups.assert_called_with([1, 2, 3])
+        finally:
+            if prev:
+                os.initgroups = prev
 
-            x = DaemonContext(workdir='/opt/workdir', fake=True)
-            x.stdfds = [0, 1, 2]
-            x._detach = Mock()
-            with x:
-                pass
-            self.assertFalse(x._detach.called)
 
-            x.after_chdir = Mock()
+@skip_if_win32()
+class test_detached(Case):
+
+    def test_without_resource(self):
+        prev, platforms.resource = platforms.resource, None
+        try:
+            with self.assertRaises(RuntimeError):
+                detached()
+        finally:
+            platforms.resource = prev
+
+    @patch('celery.platforms._create_pidlock')
+    @patch('celery.platforms.signals')
+    @patch('celery.platforms.maybe_drop_privileges')
+    @patch('os.geteuid')
+    @patch(open_fqdn)
+    def test_default(self, open, geteuid, maybe_drop,
+                     signals, pidlock):
+        geteuid.return_value = 0
+        context = detached(uid='user', gid='group')
+        self.assertIsInstance(context, DaemonContext)
+        signals.reset.assert_called_with('SIGCLD')
+        maybe_drop.assert_called_with(uid='user', gid='group')
+        open.return_value = Mock()
+
+        geteuid.return_value = 5001
+        context = detached(uid='user', gid='group', logfile='/foo/bar')
+        self.assertIsInstance(context, DaemonContext)
+        self.assertTrue(context.after_chdir)
+        context.after_chdir()
+        open.assert_called_with('/foo/bar', 'a')
+        open.return_value.close.assert_called_with()
+
+        context = detached(pidfile='/foo/bar/pid')
+        self.assertIsInstance(context, DaemonContext)
+        self.assertTrue(context.after_chdir)
+        context.after_chdir()
+        pidlock.assert_called_with('/foo/bar/pid')
+
+
+@skip_if_win32()
+class test_DaemonContext(Case):
+
+    @patch('os.fork')
+    @patch('os.setsid')
+    @patch('os._exit')
+    @patch('os.chdir')
+    @patch('os.umask')
+    @patch('os.close')
+    @patch('os.closerange')
+    @patch('os.open')
+    @patch('os.dup2')
+    def test_open(self, dup2, open, close, closer, umask, chdir,
+                  _exit, setsid, fork):
+        x = DaemonContext(workdir='/opt/workdir', umask=0o22)
+        x.stdfds = [0, 1, 2]
+
+        fork.return_value = 0
+        with x:
+            self.assertTrue(x._is_open)
             with x:
                 pass
-            x.after_chdir.assert_called_with()
-
-            x = DaemonContext(workdir='/opt/workdir', umask='0755')
-            self.assertEqual(x.umask, 493)
-            x = DaemonContext(workdir='/opt/workdir', umask='493')
-            self.assertEqual(x.umask, 493)
-
-            x.redirect_to_null(None)
-
-            with patch('celery.platforms.mputil') as mputil:
-                x = DaemonContext(after_forkers=True)
-                x.open()
-                mputil._run_after_forkers.assert_called_with()
-                x = DaemonContext(after_forkers=False)
-                x.open()
-
-    class test_Pidfile(Case):
-
-        @patch('celery.platforms.Pidfile')
-        def test_create_pidlock(self, Pidfile):
-            p = Pidfile.return_value = Mock()
-            p.is_locked.return_value = True
-            p.remove_if_stale.return_value = False
-            with override_stdouts() as (_, err):
-                with self.assertRaises(SystemExit):
-                    create_pidlock('/var/pid')
-                self.assertIn('already exists', err.getvalue())
-
-            p.remove_if_stale.return_value = True
-            ret = create_pidlock('/var/pid')
-            self.assertIs(ret, p)
-
-        def test_context(self):
-            p = Pidfile('/var/pid')
-            p.write_pid = Mock()
-            p.remove = Mock()
-
-            with p as _p:
-                self.assertIs(_p, p)
-            p.write_pid.assert_called_with()
-            p.remove.assert_called_with()
+        self.assertEqual(fork.call_count, 2)
+        setsid.assert_called_with()
+        self.assertFalse(_exit.called)
+
+        chdir.assert_called_with(x.workdir)
+        umask.assert_called_with(0o22)
+        self.assertTrue(dup2.called)
+
+        fork.reset_mock()
+        fork.return_value = 1
+        x = DaemonContext(workdir='/opt/workdir')
+        x.stdfds = [0, 1, 2]
+        with x:
+            pass
+        self.assertEqual(fork.call_count, 1)
+        _exit.assert_called_with(0)
+
+        x = DaemonContext(workdir='/opt/workdir', fake=True)
+        x.stdfds = [0, 1, 2]
+        x._detach = Mock()
+        with x:
+            pass
+        self.assertFalse(x._detach.called)
+
+        x.after_chdir = Mock()
+        with x:
+            pass
+        x.after_chdir.assert_called_with()
+
+        x = DaemonContext(workdir='/opt/workdir', umask='0755')
+        self.assertEqual(x.umask, 493)
+        x = DaemonContext(workdir='/opt/workdir', umask='493')
+        self.assertEqual(x.umask, 493)
+
+        x.redirect_to_null(None)
+
+        with patch('celery.platforms.mputil') as mputil:
+            x = DaemonContext(after_forkers=True)
+            x.open()
+            mputil._run_after_forkers.assert_called_with()
+            x = DaemonContext(after_forkers=False)
+            x.open()
+
+
+@skip_if_win32()
+class test_Pidfile(Case):
+
+    @patch('celery.platforms.Pidfile')
+    def test_create_pidlock(self, Pidfile):
+        p = Pidfile.return_value = Mock()
+        p.is_locked.return_value = True
+        p.remove_if_stale.return_value = False
+        with override_stdouts() as (_, err):
+            with self.assertRaises(SystemExit):
+                create_pidlock('/var/pid')
+            self.assertIn('already exists', err.getvalue())
+
+        p.remove_if_stale.return_value = True
+        ret = create_pidlock('/var/pid')
+        self.assertIs(ret, p)
+
+    def test_context(self):
+        p = Pidfile('/var/pid')
+        p.write_pid = Mock()
+        p.remove = Mock()
+
+        with p as _p:
+            self.assertIs(_p, p)
+        p.write_pid.assert_called_with()
+        p.remove.assert_called_with()
+
+    def test_acquire_raises_LockFailed(self):
+        p = Pidfile('/var/pid')
+        p.write_pid = Mock()
+        p.write_pid.side_effect = OSError()
+
+        with self.assertRaises(LockFailed):
+            with p:
+                pass
 
-        def test_acquire_raises_LockFailed(self):
+    @patch('os.path.exists')
+    def test_is_locked(self, exists):
+        p = Pidfile('/var/pid')
+        exists.return_value = True
+        self.assertTrue(p.is_locked())
+        exists.return_value = False
+        self.assertFalse(p.is_locked())
+
+    def test_read_pid(self):
+        with mock_open() as s:
+            s.write('1816\n')
+            s.seek(0)
             p = Pidfile('/var/pid')
-            p.write_pid = Mock()
-            p.write_pid.side_effect = OSError()
-
-            with self.assertRaises(LockFailed):
-                with p:
-                    pass
+            self.assertEqual(p.read_pid(), 1816)
 
-        @patch('os.path.exists')
-        def test_is_locked(self, exists):
-            p = Pidfile('/var/pid')
-            exists.return_value = True
-            self.assertTrue(p.is_locked())
-            exists.return_value = False
-            self.assertFalse(p.is_locked())
-
-        def test_read_pid(self):
-            with mock_open() as s:
-                s.write('1816\n')
-                s.seek(0)
-                p = Pidfile('/var/pid')
-                self.assertEqual(p.read_pid(), 1816)
-
-        def test_read_pid_partially_written(self):
-            with mock_open() as s:
-                s.write('1816')
-                s.seek(0)
-                p = Pidfile('/var/pid')
-                with self.assertRaises(ValueError):
-                    p.read_pid()
-
-        def test_read_pid_raises_ENOENT(self):
-            exc = IOError()
-            exc.errno = errno.ENOENT
-            with mock_open(side_effect=exc):
-                p = Pidfile('/var/pid')
-                self.assertIsNone(p.read_pid())
-
-        def test_read_pid_raises_IOError(self):
-            exc = IOError()
-            exc.errno = errno.EAGAIN
-            with mock_open(side_effect=exc):
-                p = Pidfile('/var/pid')
-                with self.assertRaises(IOError):
-                    p.read_pid()
-
-        def test_read_pid_bogus_pidfile(self):
-            with mock_open() as s:
-                s.write('eighteensixteen\n')
-                s.seek(0)
-                p = Pidfile('/var/pid')
-                with self.assertRaises(ValueError):
-                    p.read_pid()
-
-        @patch('os.unlink')
-        def test_remove(self, unlink):
-            unlink.return_value = True
+    def test_read_pid_partially_written(self):
+        with mock_open() as s:
+            s.write('1816')
+            s.seek(0)
             p = Pidfile('/var/pid')
-            p.remove()
-            unlink.assert_called_with(p.path)
+            with self.assertRaises(ValueError):
+                p.read_pid()
 
-        @patch('os.unlink')
-        def test_remove_ENOENT(self, unlink):
-            exc = OSError()
-            exc.errno = errno.ENOENT
-            unlink.side_effect = exc
+    def test_read_pid_raises_ENOENT(self):
+        exc = IOError()
+        exc.errno = errno.ENOENT
+        with mock_open(side_effect=exc):
             p = Pidfile('/var/pid')
-            p.remove()
-            unlink.assert_called_with(p.path)
+            self.assertIsNone(p.read_pid())
 
-        @patch('os.unlink')
-        def test_remove_EACCES(self, unlink):
-            exc = OSError()
-            exc.errno = errno.EACCES
-            unlink.side_effect = exc
+    def test_read_pid_raises_IOError(self):
+        exc = IOError()
+        exc.errno = errno.EAGAIN
+        with mock_open(side_effect=exc):
             p = Pidfile('/var/pid')
-            p.remove()
-            unlink.assert_called_with(p.path)
+            with self.assertRaises(IOError):
+                p.read_pid()
 
-        @patch('os.unlink')
-        def test_remove_OSError(self, unlink):
-            exc = OSError()
-            exc.errno = errno.EAGAIN
-            unlink.side_effect = exc
+    def test_read_pid_bogus_pidfile(self):
+        with mock_open() as s:
+            s.write('eighteensixteen\n')
+            s.seek(0)
             p = Pidfile('/var/pid')
-            with self.assertRaises(OSError):
-                p.remove()
-            unlink.assert_called_with(p.path)
-
-        @patch('os.kill')
-        def test_remove_if_stale_process_alive(self, kill):
+            with self.assertRaises(ValueError):
+                p.read_pid()
+
+    @patch('os.unlink')
+    def test_remove(self, unlink):
+        unlink.return_value = True
+        p = Pidfile('/var/pid')
+        p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.unlink')
+    def test_remove_ENOENT(self, unlink):
+        exc = OSError()
+        exc.errno = errno.ENOENT
+        unlink.side_effect = exc
+        p = Pidfile('/var/pid')
+        p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.unlink')
+    def test_remove_EACCES(self, unlink):
+        exc = OSError()
+        exc.errno = errno.EACCES
+        unlink.side_effect = exc
+        p = Pidfile('/var/pid')
+        p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.unlink')
+    def test_remove_OSError(self, unlink):
+        exc = OSError()
+        exc.errno = errno.EAGAIN
+        unlink.side_effect = exc
+        p = Pidfile('/var/pid')
+        with self.assertRaises(OSError):
+            p.remove()
+        unlink.assert_called_with(p.path)
+
+    @patch('os.kill')
+    def test_remove_if_stale_process_alive(self, kill):
+        p = Pidfile('/var/pid')
+        p.read_pid = Mock()
+        p.read_pid.return_value = 1816
+        kill.return_value = 0
+        self.assertFalse(p.remove_if_stale())
+        kill.assert_called_with(1816, 0)
+        p.read_pid.assert_called_with()
+
+        kill.side_effect = OSError()
+        kill.side_effect.errno = errno.ENOENT
+        self.assertFalse(p.remove_if_stale())
+
+    @patch('os.kill')
+    def test_remove_if_stale_process_dead(self, kill):
+        with override_stdouts():
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
             p.read_pid.return_value = 1816
-            kill.return_value = 0
-            self.assertFalse(p.remove_if_stale())
+            p.remove = Mock()
+            exc = OSError()
+            exc.errno = errno.ESRCH
+            kill.side_effect = exc
+            self.assertTrue(p.remove_if_stale())
             kill.assert_called_with(1816, 0)
-            p.read_pid.assert_called_with()
-
-            kill.side_effect = OSError()
-            kill.side_effect.errno = errno.ENOENT
-            self.assertFalse(p.remove_if_stale())
-
-        @patch('os.kill')
-        def test_remove_if_stale_process_dead(self, kill):
-            with override_stdouts():
-                p = Pidfile('/var/pid')
-                p.read_pid = Mock()
-                p.read_pid.return_value = 1816
-                p.remove = Mock()
-                exc = OSError()
-                exc.errno = errno.ESRCH
-                kill.side_effect = exc
-                self.assertTrue(p.remove_if_stale())
-                kill.assert_called_with(1816, 0)
-                p.remove.assert_called_with()
-
-        def test_remove_if_stale_broken_pid(self):
-            with override_stdouts():
-                p = Pidfile('/var/pid')
-                p.read_pid = Mock()
-                p.read_pid.side_effect = ValueError()
-                p.remove = Mock()
-
-                self.assertTrue(p.remove_if_stale())
-                p.remove.assert_called_with()
-
-        def test_remove_if_stale_no_pidfile(self):
+            p.remove.assert_called_with()
+
+    def test_remove_if_stale_broken_pid(self):
+        with override_stdouts():
             p = Pidfile('/var/pid')
             p.read_pid = Mock()
-            p.read_pid.return_value = None
+            p.read_pid.side_effect = ValueError()
             p.remove = Mock()
 
             self.assertTrue(p.remove_if_stale())
             p.remove.assert_called_with()
 
-        @patch('os.fsync')
-        @patch('os.getpid')
-        @patch('os.open')
-        @patch('os.fdopen')
-        @patch(open_fqdn)
-        def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
-            getpid.return_value = 1816
-            osopen.return_value = 13
-            w = fdopen.return_value = WhateverIO()
-            w.close = Mock()
-            r = open_.return_value = WhateverIO()
-            r.write('1816\n')
-            r.seek(0)
-
-            p = Pidfile('/var/pid')
+    def test_remove_if_stale_no_pidfile(self):
+        p = Pidfile('/var/pid')
+        p.read_pid = Mock()
+        p.read_pid.return_value = None
+        p.remove = Mock()
+
+        self.assertTrue(p.remove_if_stale())
+        p.remove.assert_called_with()
+
+    @patch('os.fsync')
+    @patch('os.getpid')
+    @patch('os.open')
+    @patch('os.fdopen')
+    @patch(open_fqdn)
+    def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
+        getpid.return_value = 1816
+        osopen.return_value = 13
+        w = fdopen.return_value = WhateverIO()
+        w.close = Mock()
+        r = open_.return_value = WhateverIO()
+        r.write('1816\n')
+        r.seek(0)
+
+        p = Pidfile('/var/pid')
+        p.write_pid()
+        w.seek(0)
+        self.assertEqual(w.readline(), '1816\n')
+        self.assertTrue(w.close.called)
+        getpid.assert_called_with()
+        osopen.assert_called_with(
+            p.path, platforms.PIDFILE_FLAGS, platforms.PIDFILE_MODE,
+        )
+        fdopen.assert_called_with(13, 'w')
+        fsync.assert_called_with(13)
+        open_.assert_called_with(p.path)
+
+    @patch('os.fsync')
+    @patch('os.getpid')
+    @patch('os.open')
+    @patch('os.fdopen')
+    @patch(open_fqdn)
+    def test_write_reread_fails(self, open_, fdopen,
+                                osopen, getpid, fsync):
+        getpid.return_value = 1816
+        osopen.return_value = 13
+        w = fdopen.return_value = WhateverIO()
+        w.close = Mock()
+        r = open_.return_value = WhateverIO()
+        r.write('11816\n')
+        r.seek(0)
+
+        p = Pidfile('/var/pid')
+        with self.assertRaises(LockFailed):
             p.write_pid()
-            w.seek(0)
-            self.assertEqual(w.readline(), '1816\n')
-            self.assertTrue(w.close.called)
-            getpid.assert_called_with()
-            osopen.assert_called_with(p.path, platforms.PIDFILE_FLAGS,
-                                      platforms.PIDFILE_MODE)
-            fdopen.assert_called_with(13, 'w')
-            fsync.assert_called_with(13)
-            open_.assert_called_with(p.path)
-
-        @patch('os.fsync')
-        @patch('os.getpid')
-        @patch('os.open')
-        @patch('os.fdopen')
-        @patch(open_fqdn)
-        def test_write_reread_fails(self, open_, fdopen,
-                                    osopen, getpid, fsync):
-            getpid.return_value = 1816
-            osopen.return_value = 13
-            w = fdopen.return_value = WhateverIO()
-            w.close = Mock()
-            r = open_.return_value = WhateverIO()
-            r.write('11816\n')
-            r.seek(0)
 
-            p = Pidfile('/var/pid')
-            with self.assertRaises(LockFailed):
-                p.write_pid()
 
-    class test_setgroups(Case):
+class test_setgroups(Case):
 
-        @patch('os.setgroups', create=True)
-        def test_setgroups_hack_ValueError(self, setgroups):
+    @patch('os.setgroups', create=True)
+    def test_setgroups_hack_ValueError(self, setgroups):
 
-            def on_setgroups(groups):
-                if len(groups) <= 200:
-                    setgroups.return_value = True
-                    return
-                raise ValueError()
-            setgroups.side_effect = on_setgroups
+        def on_setgroups(groups):
+            if len(groups) <= 200:
+                setgroups.return_value = True
+                return
+            raise ValueError()
+        setgroups.side_effect = on_setgroups
+        _setgroups_hack(list(range(400)))
+
+        setgroups.side_effect = ValueError()
+        with self.assertRaises(ValueError):
             _setgroups_hack(list(range(400)))
 
-            setgroups.side_effect = ValueError()
-            with self.assertRaises(ValueError):
-                _setgroups_hack(list(range(400)))
+    @patch('os.setgroups', create=True)
+    def test_setgroups_hack_OSError(self, setgroups):
+        exc = OSError()
+        exc.errno = errno.EINVAL
 
-        @patch('os.setgroups', create=True)
-        def test_setgroups_hack_OSError(self, setgroups):
-            exc = OSError()
-            exc.errno = errno.EINVAL
+        def on_setgroups(groups):
+            if len(groups) <= 200:
+                setgroups.return_value = True
+                return
+            raise exc
+        setgroups.side_effect = on_setgroups
 
-            def on_setgroups(groups):
-                if len(groups) <= 200:
-                    setgroups.return_value = True
-                    return
-                raise exc
-            setgroups.side_effect = on_setgroups
+        _setgroups_hack(list(range(400)))
 
+        setgroups.side_effect = exc
+        with self.assertRaises(OSError):
             _setgroups_hack(list(range(400)))
 
-            setgroups.side_effect = exc
-            with self.assertRaises(OSError):
-                _setgroups_hack(list(range(400)))
+        exc2 = OSError()
+        exc.errno = errno.ESRCH
+        setgroups.side_effect = exc2
+        with self.assertRaises(OSError):
+            _setgroups_hack(list(range(400)))
 
-            exc2 = OSError()
-            exc.errno = errno.ESRCH
-            setgroups.side_effect = exc2
-            with self.assertRaises(OSError):
-                _setgroups_hack(list(range(400)))
-
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups(self, hack, sysconf):
-            sysconf.return_value = 100
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups(self, hack, sysconf):
+        sysconf.return_value = 100
+        setgroups(list(range(400)))
+        hack.assert_called_with(list(range(100)))
+
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups_sysconf_raises(self, hack, sysconf):
+        sysconf.side_effect = ValueError()
+        setgroups(list(range(400)))
+        hack.assert_called_with(list(range(400)))
+
+    @patch('os.getgroups')
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups_raises_ESRCH(self, hack, sysconf, getgroups):
+        sysconf.side_effect = ValueError()
+        esrch = OSError()
+        esrch.errno = errno.ESRCH
+        hack.side_effect = esrch
+        with self.assertRaises(OSError):
             setgroups(list(range(400)))
-            hack.assert_called_with(list(range(100)))
 
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups_sysconf_raises(self, hack, sysconf):
-            sysconf.side_effect = ValueError()
-            setgroups(list(range(400)))
-            hack.assert_called_with(list(range(400)))
-
-        @patch('os.getgroups')
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups_raises_ESRCH(self, hack, sysconf, getgroups):
-            sysconf.side_effect = ValueError()
-            esrch = OSError()
-            esrch.errno = errno.ESRCH
-            hack.side_effect = esrch
-            with self.assertRaises(OSError):
-                setgroups(list(range(400)))
-
-        @patch('os.getgroups')
-        @patch('os.sysconf')
-        @patch('celery.platforms._setgroups_hack')
-        def test_setgroups_raises_EPERM(self, hack, sysconf, getgroups):
-            sysconf.side_effect = ValueError()
-            eperm = OSError()
-            eperm.errno = errno.EPERM
-            hack.side_effect = eperm
-            getgroups.return_value = list(range(400))
+    @patch('os.getgroups')
+    @patch('os.sysconf')
+    @patch('celery.platforms._setgroups_hack')
+    def test_setgroups_raises_EPERM(self, hack, sysconf, getgroups):
+        sysconf.side_effect = ValueError()
+        eperm = OSError()
+        eperm.errno = errno.EPERM
+        hack.side_effect = eperm
+        getgroups.return_value = list(range(400))
+        setgroups(list(range(400)))
+        getgroups.assert_called_with()
+
+        getgroups.return_value = [1000]
+        with self.assertRaises(OSError):
             setgroups(list(range(400)))
-            getgroups.assert_called_with()
-
-            getgroups.return_value = [1000]
-            with self.assertRaises(OSError):
-                setgroups(list(range(400)))
-            getgroups.assert_called_with()
+        getgroups.assert_called_with()
 
 
 class test_check_privileges(Case):

+ 3 - 9
celery/tests/utils/test_sysinfo.py

@@ -1,17 +1,14 @@
 from __future__ import absolute_import, unicode_literals
 
-import os
-
 from celery.utils.sysinfo import load_average, df
 
-from celery.tests.case import Case, SkipTest, patch
+from celery.tests.case import Case, patch, skip_unless_symbol
 
 
+@skip_unless_symbol('os.getloadavg')
 class test_load_average(Case):
 
     def test_avg(self):
-        if not hasattr(os, 'getloadavg'):
-            raise SkipTest('getloadavg not available')
         with patch('os.getloadavg') as getloadavg:
             getloadavg.return_value = 0.54736328125, 0.6357421875, 0.69921875
             l = load_average()
@@ -19,13 +16,10 @@ class test_load_average(Case):
             self.assertEqual(l, (0.55, 0.64, 0.7))
 
 
+@skip_unless_symbol('posix.statvfs_result')
 class test_df(Case):
 
     def test_df(self):
-        try:
-            from posix import statvfs_result  # noqa
-        except ImportError:
-            raise SkipTest('statvfs not available')
         x = df('/')
         self.assertTrue(x.total_blocks)
         self.assertTrue(x.available)

+ 2 - 4
celery/tests/utils/test_term.py

@@ -7,15 +7,13 @@ from celery.utils import term
 from celery.utils.term import colored, fg
 from celery.five import text_t
 
-from celery.tests.case import Case, SkipTest
+from celery.tests.case import Case, skip_if_win32
 
 
+@skip_if_win32()
 class test_colored(Case):
 
     def setUp(self):
-        if sys.platform == 'win32':
-            raise SkipTest('Colors not supported on Windows')
-
         self._prev_encoding = sys.getdefaultencoding
 
         def getdefaultencoding():

+ 2 - 4
celery/tests/worker/test_components.py

@@ -5,10 +5,9 @@ from __future__ import absolute_import, unicode_literals
 # point [-ask]
 
 from celery.exceptions import ImproperlyConfigured
-from celery.platforms import IS_WINDOWS
 from celery.worker.components import Beat, Hub, Pool, Timer
 
-from celery.tests.case import AppCase, Mock, SkipTest, patch
+from celery.tests.case import AppCase, Mock, patch, skip_if_win32
 
 
 class test_Timer(AppCase):
@@ -61,9 +60,8 @@ class test_Pool(AppCase):
         comp.close(w)
         comp.terminate(w)
 
+    @skip_if_win32()
     def test_create_when_eventloop(self):
-        if IS_WINDOWS:
-            raise SkipTest('Win32')
         w = Mock()
         w.use_eventloop = w.pool_putlocks = w.pool_cls.uses_semaphore = True
         comp = Pool(w)

+ 5 - 6
celery/tests/worker/test_consumer.py

@@ -13,7 +13,9 @@ from celery.worker.consumer.heart import Heart
 from celery.worker.consumer.mingle import Mingle
 from celery.worker.consumer.tasks import Tasks
 
-from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch
+from celery.tests.case import (
+    AppCase, ContextMock, Mock, call, patch, skip_if_python3,
+)
 
 
 class test_Consumer(AppCase):
@@ -43,14 +45,11 @@ class test_Consumer(AppCase):
         c = self.get_consumer()
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
 
+    @skip_if_python3(reason='buffer type not available')
     def test_dump_body_buffer(self):
         msg = Mock()
         msg.body = 'str'
-        try:
-            buf = buffer(msg.body)
-        except NameError:
-            raise SkipTest('buffer type not available')
-        self.assertTrue(dump_body(msg, buf))
+        self.assertTrue(dump_body(msg, buffer(msg.body)))
 
     def test_sets_heartbeat(self):
         c = self.get_consumer(amqheartbeat=10)

+ 7 - 11
celery/tests/worker/test_request.py

@@ -45,11 +45,10 @@ from celery.tests.case import (
     AppCase,
     Case,
     Mock,
-    SkipTest,
     TaskMessage,
-    assert_signal_called,
     task_message_from_sig,
     patch,
+    skip_if_python3,
 )
 
 
@@ -124,12 +123,9 @@ def jail(app, task_id, name, args, kwargs):
     ).retval
 
 
+@skip_if_python3
 class test_default_encode(AppCase):
 
-    def setup(self):
-        if sys.version_info >= (3, 0):
-            raise SkipTest('py3k: not relevant')
-
     def test_jython(self):
         prev, sys.platform = sys.platform, 'java 1.6.1'
         try:
@@ -430,7 +426,7 @@ class test_Request(RequestCase):
         signum = signal.SIGTERM
         job = self.get_request(self.mytask.s(1, f='x'))
         job._apply_result = Mock(name='_apply_result')
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
             job.time_start = monotonic()
@@ -446,7 +442,7 @@ class test_Request(RequestCase):
         pool = Mock()
         signum = signal.SIGTERM
         job = self.get_request(self.mytask.s(1, f='x'))
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
             job.time_start = monotonic()
@@ -467,7 +463,7 @@ class test_Request(RequestCase):
         job = self.get_request(self.mytask.s(1, f='x').set(
             expires=datetime.utcnow() - timedelta(days=1)
         ))
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=False, expired=True, signum=None):
             job.revoked()
@@ -506,7 +502,7 @@ class test_Request(RequestCase):
 
     def test_revoked(self):
         job = self.xRequest()
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=False, expired=False, signum=None):
             revoked.add(job.id)
@@ -555,7 +551,7 @@ class test_Request(RequestCase):
         signum = signal.SIGTERM
         pool = Mock()
         job = self.xRequest()
-        with assert_signal_called(
+        with self.assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
             job.terminate(pool, signal='TERM')

+ 2 - 2
celery/tests/worker/test_worker.py

@@ -30,7 +30,7 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
-from celery.tests.case import AppCase, Mock, SkipTest, TaskMessage, patch
+from celery.tests.case import AppCase, Mock, TaskMessage, patch, todo
 
 
 def MockStep(step=None):
@@ -849,8 +849,8 @@ class test_WorkController(AppCase):
             self.worker._send_worker_shutdown()
             ws.send.assert_called_with(sender=self.worker)
 
+    @todo('unstable test')
     def test_process_shutdown_on_worker_shutdown(self):
-        raise SkipTest('unstable test')
         from celery.concurrency.prefork import process_destructor
         from celery.concurrency.asynpool import Worker
         with patch('celery.signals.worker_process_shutdown') as ws: