123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- from __future__ import generators
- import os
- import sys
- import __builtin__
- from StringIO import StringIO
- from billiard.utils.functional import wraps
class GeneratorContextManager(object):
    """Context manager driven by a generator that yields exactly once.

    Used by ``fallback_contextmanager`` when :mod:`contextlib` cannot be
    imported.  The value yielded by the generator is returned from
    :meth:`__enter__`; the code after the ``yield`` runs in :meth:`__exit__`.
    """

    def __init__(self, gen):
        # ``gen`` must provide ``next()``; ``throw()`` is optional (plain
        # generators on Python < 2.5 do not have it).
        self.gen = gen

    def __enter__(self):
        """Run the generator up to its ``yield`` and return the yielded value.

        :raises RuntimeError: if the generator finishes without yielding.
        """
        try:
            return self.gen.next()
        except StopIteration:
            raise RuntimeError("generator didn't yield")

    def __exit__(self, type, value, traceback):
        """Resume the generator after the ``with`` body has finished.

        :returns: a true value only when an exception raised by the body
            should be suppressed, i.e. the generator caught it and then
            stopped.  Otherwise a false value, so the caller re-raises.
        :raises RuntimeError: if the generator yields a second time.
        """
        if type is None:
            try:
                self.gen.next()
            except StopIteration:
                return
            raise RuntimeError("generator didn't stop")

        if value is None:
            # The exception was raised as a bare class; instantiate it so
            # the identity checks below compare against a real instance.
            value = type()

        # Look ``throw`` up *before* calling it, so that an AttributeError
        # raised by the generator's own code is not mistaken for a
        # generator that lacks ``throw()``.
        throw = getattr(self.gen, "throw", None)
        if throw is None:
            # No way to deliver the exception to the generator; propagate it.
            raise value

        try:
            throw(type, value, traceback)
        except StopIteration:
            # Suppress only when the generator swallowed the exception.  A
            # StopIteration raised by the ``with`` body itself must propagate.
            return sys.exc_info()[1] is not value
        except:
            # Anything other than the exception we threw in is a new error
            # and propagates.  If the generator let ``value`` through, fall
            # through and return None so the caller re-raises it.
            if sys.exc_info()[1] is not value:
                raise
        else:
            raise RuntimeError("generator didn't stop after throw()")
def fallback_contextmanager(fun):
    """Minimal stand-in for ``contextlib.contextmanager``.

    Turns a generator function that yields once into a factory returning
    :class:`GeneratorContextManager` instances.  ``wraps`` copies the
    generator function's name and docstring onto the factory, matching what
    ``contextlib.contextmanager`` does.
    """
    @wraps(fun)
    def helper(*args, **kwds):
        return GeneratorContextManager(fun(*args, **kwds))
    return helper
def execute_context(context, fun):
    """Call ``fun`` inside ``context``, emulating a ``with`` statement.

    ``fun`` receives the value returned by ``context.__enter__()``.
    ``context.__exit__`` is always called afterwards: with ``(None, None,
    None)`` on success, or with the exception info if ``fun`` raised.

    :returns: the return value of ``fun``, or ``None`` if ``fun`` raised
        and ``__exit__`` suppressed the exception.
    :raises: whatever ``fun`` raised, unless ``__exit__`` returned a true
        value (the same rule as the ``with`` statement).  Exceptions raised
        by ``__exit__`` itself also propagate.
    """
    val = context.__enter__()
    try:
        retval = fun(val)
    except:
        # Mirror the PEP 343 translation: hand the exception to __exit__ and
        # re-raise it unless __exit__ asks for it to be suppressed.
        if not context.__exit__(*sys.exc_info()):
            raise
        return None
    context.__exit__(None, None, None)
    return retval
- try:
- from contextlib import contextmanager
- except ImportError:
- contextmanager = fallback_contextmanager
- from celery.utils import noop
@contextmanager
def eager_tasks():
    """Context manager that sets ``celery.conf.ALWAYS_EAGER`` to ``True``.

    Yields ``True``.  The previous value of ``ALWAYS_EAGER`` is restored
    when the block exits.  On the error path this relies on generator
    ``throw()`` support (Python 2.5+ / ``contextlib``).
    """
    from celery import conf
    prev = conf.ALWAYS_EAGER
    conf.ALWAYS_EAGER = True
    # ``yield`` inside ``try/finally`` is a SyntaxError before Python 2.5,
    # so restore explicitly on both the error path and the normal path.
    try:
        yield True
    except:
        conf.ALWAYS_EAGER = prev
        raise
    conf.ALWAYS_EAGER = prev
def with_environ(env_name, env_value):
    """Decorator factory: run the function with ``os.environ[env_name]``
    set to ``env_value``.

    After each call the variable is set back to its previous value.  If it
    was not set before the call, it is removed again.
    """
    def _envpatched(fun):
        @wraps(fun)
        def _patch_environ(*args, **kwargs):
            prev_val = os.environ.get(env_name)
            os.environ[env_name] = env_value
            try:
                return fun(*args, **kwargs)
            finally:
                if prev_val is not None:
                    os.environ[env_name] = prev_val
                elif env_name in os.environ:
                    # The variable did not exist before the call; don't
                    # leak it into whatever runs next.
                    del os.environ[env_name]
        return _patch_environ
    return _envpatched
def sleepdeprived(fun):
    """Decorator: make ``time.sleep`` a no-op while ``fun`` runs.

    Each call swaps ``time.sleep`` for ``noop``, invokes ``fun``, and puts
    the original ``time.sleep`` back even if ``fun`` raises.
    """
    @wraps(fun)
    def wrapper(*args, **kwargs):
        import time
        saved_sleep = getattr(time, "sleep")
        setattr(time, "sleep", noop)
        try:
            outcome = fun(*args, **kwargs)
        finally:
            setattr(time, "sleep", saved_sleep)
        return outcome
    return wrapper
def skip_if_environ(env_var_name):
    """Decorator factory: skip a test while ``env_var_name`` is set.

    When the environment variable has a non-empty value at call time, a
    ``SKIP`` note is written to ``sys.stderr`` and the test body is not run
    (the call returns ``None``).  Otherwise the test runs normally.
    """
    def decorator(fun):
        @wraps(fun)
        def guarded(*args, **kwargs):
            if not os.environ.get(env_var_name):
                return fun(*args, **kwargs)
            note = "SKIP %s: %s set\n" % (fun.__name__, env_var_name)
            sys.stderr.write(note)
            return None
        return guarded
    return decorator
def skip_if_quick(fun):
    """Skip ``fun`` whenever the ``QUICKTEST`` environment variable is set."""
    skipper = skip_if_environ("QUICKTEST")
    return skipper(fun)
def _skip_test(reason, sign):
    """Build a decorator that replaces a test with a do-nothing stub.

    The stub never calls the original function.  It writes
    ``"<sign>: <reason> "`` to ``sys.stderr`` with a trailing space and no
    newline, then returns ``None``.
    """
    def _decorate(fun):
        def _stub(*args, **kwargs):
            note = "%s: %s " % (sign, reason)
            sys.stderr.write(note)
        return wraps(fun)(_stub)
    return _decorate
def todo(reason):
    """TODO test decorator: replace the test with a stub reporting
    ``TODO: <reason>`` on stderr."""
    return _skip_test(reason=reason, sign="TODO")
def skip(reason):
    """Skip test decorator: replace the test with a stub reporting
    ``SKIP: <reason>`` on stderr."""
    return _skip_test(reason=reason, sign="SKIP")
def skip_if(predicate, reason):
    """Skip test if ``predicate`` is true.

    Returns a decorator that applies ``skip(reason)`` when ``predicate`` is
    truthy, and otherwise hands the test function back unchanged.
    """
    def _inner(fun):
        if predicate:
            return skip(reason)(fun)
        return fun
    return _inner
def skip_unless(predicate, reason):
    """Skip test unless ``predicate`` is true, i.e. skip when it is falsy."""
    should_skip = not predicate
    return skip_if(should_skip, reason)
# Taken from
# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
@contextmanager
def mask_modules(*modnames):
    """Ban some modules from being importable inside the context.

    While the context is active, ``__import__`` raises :exc:`ImportError`
    for any name listed in ``modnames``.  Matching is exact, so masking
    ``"foo"`` does not block ``"foo.bar"``.  The real ``__import__`` is
    restored when the context exits.  On the error path this relies on
    generator ``throw()`` support (Python 2.5+ / ``contextlib``).

    For example::

        >>> with mask_modules("sys"):
        ...     try:
        ...         import sys
        ...     except ImportError:
        ...         print("sys not found")
        sys not found

        >>> import sys  # importable again once the context has exited
    """
    realimport = __builtin__.__import__

    def myimp(name, *args, **kwargs):
        if name in modnames:
            raise ImportError("No module named %s" % name)
        return realimport(name, *args, **kwargs)

    __builtin__.__import__ = myimp
    # ``yield`` inside ``try/finally`` is a SyntaxError before Python 2.5,
    # so restore explicitly on both the error path and the normal path.
    # Without this, an exception in the body would leave every later
    # import of the masked names broken.
    try:
        yield True
    except:
        __builtin__.__import__ = realimport
        raise
    __builtin__.__import__ = realimport
@contextmanager
def override_stdouts():
    """Override ``sys.stdout`` and ``sys.stderr`` with ``StringIO``.

    ``sys.__stdout__`` and ``sys.__stderr__`` are pointed at the same
    buffers.  Yields the ``(stdout, stderr)`` buffers.  All four attributes
    are restored to the objects they referred to before entry when the
    context exits.  On the error path this relies on generator ``throw()``
    support (Python 2.5+ / ``contextlib``).
    """
    prev_out, prev_err = sys.stdout, sys.stderr
    # sys.__stdout__/__stderr__ can differ from sys.stdout/stderr (e.g. when
    # a test runner captures output), so remember them separately.
    prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
    mystdout, mystderr = StringIO(), StringIO()
    sys.stdout = sys.__stdout__ = mystdout
    sys.stderr = sys.__stderr__ = mystderr

    def restore():
        sys.stdout, sys.stderr = prev_out, prev_err
        sys.__stdout__, sys.__stderr__ = prev_rout, prev_rerr

    # ``yield`` inside ``try/finally`` is a SyntaxError before Python 2.5,
    # so restore explicitly on both the error path and the normal path.
    try:
        yield mystdout, mystderr
    except:
        restore()
        raise
    restore()
|