utils.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. from __future__ import generators
  2. import os
  3. import sys
  4. import __builtin__
  5. from StringIO import StringIO
  6. from billiard.utils.functional import wraps
  7. class GeneratorContextManager(object):
  8. def __init__(self, gen):
  9. self.gen = gen
  10. def __enter__(self):
  11. try:
  12. return self.gen.next()
  13. except StopIteration:
  14. raise RuntimeError("generator didn't yield")
  15. def __exit__(self, type, value, traceback):
  16. if type is None:
  17. try:
  18. self.gen.next()
  19. except StopIteration:
  20. return
  21. else:
  22. raise RuntimeError("generator didn't stop")
  23. else:
  24. try:
  25. self.gen.throw(type, value, traceback)
  26. raise RuntimeError("generator didn't stop after throw()")
  27. except StopIteration:
  28. return True
  29. except AttributeError:
  30. raise value
  31. except:
  32. if sys.exc_info()[1] is not value:
  33. raise
  34. def fallback_contextmanager(fun):
  35. def helper(*args, **kwds):
  36. return GeneratorContextManager(fun(*args, **kwds))
  37. return helper
  38. def execute_context(context, fun):
  39. val = context.__enter__()
  40. exc_info = (None, None, None)
  41. retval = None
  42. try:
  43. retval = fun(val)
  44. except:
  45. exc_info = sys.exc_info()
  46. context.__exit__(*exc_info)
  47. return retval
  48. try:
  49. from contextlib import contextmanager
  50. except ImportError:
  51. contextmanager = fallback_contextmanager
  52. from celery.utils import noop
  53. @contextmanager
  54. def eager_tasks():
  55. from celery import conf
  56. prev = conf.ALWAYS_EAGER
  57. conf.ALWAYS_EAGER = True
  58. yield True
  59. conf.ALWAYS_EAGER = prev
  60. def with_environ(env_name, env_value):
  61. def _envpatched(fun):
  62. @wraps(fun)
  63. def _patch_environ(*args, **kwargs):
  64. prev_val = os.environ.get(env_name)
  65. os.environ[env_name] = env_value
  66. try:
  67. return fun(*args, **kwargs)
  68. finally:
  69. if prev_val is not None:
  70. os.environ[env_name] = prev_val
  71. return _patch_environ
  72. return _envpatched
  73. def sleepdeprived(fun):
  74. @wraps(fun)
  75. def _sleepdeprived(*args, **kwargs):
  76. import time
  77. old_sleep = time.sleep
  78. time.sleep = noop
  79. try:
  80. return fun(*args, **kwargs)
  81. finally:
  82. time.sleep = old_sleep
  83. return _sleepdeprived
  84. def skip_if_environ(env_var_name):
  85. def _wrap_test(fun):
  86. @wraps(fun)
  87. def _skips_if_environ(*args, **kwargs):
  88. if os.environ.get(env_var_name):
  89. sys.stderr.write("SKIP %s: %s set\n" % (
  90. fun.__name__, env_var_name))
  91. return
  92. return fun(*args, **kwargs)
  93. return _skips_if_environ
  94. return _wrap_test
  95. def skip_if_quick(fun):
  96. return skip_if_environ("QUICKTEST")(fun)
  97. def _skip_test(reason, sign):
  98. def _wrap_test(fun):
  99. @wraps(fun)
  100. def _skipped_test(*args, **kwargs):
  101. sys.stderr.write("%s: %s " % (sign, reason))
  102. return _skipped_test
  103. return _wrap_test
  104. def todo(reason):
  105. """TODO test decorator."""
  106. return _skip_test(reason, "TODO")
  107. def skip(reason):
  108. """Skip test decorator."""
  109. return _skip_test(reason, "SKIP")
  110. def skip_if(predicate, reason):
  111. """Skip test if predicate is ``True``."""
  112. def _inner(fun):
  113. return predicate and skip(reason)(fun) or fun
  114. return _inner
  115. def skip_unless(predicate, reason):
  116. """Skip test if predicate is ``False``."""
  117. return skip_if(not predicate, reason)
  118. # Taken from
  119. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  120. @contextmanager
  121. def mask_modules(*modnames):
  122. """Ban some modules from being importable inside the context
  123. For example:
  124. >>> with missing_modules("sys"):
  125. ... try:
  126. ... import sys
  127. ... except ImportError:
  128. ... print "sys not found"
  129. sys not found
  130. >>> import sys
  131. >>> sys.version
  132. (2, 5, 2, 'final', 0)
  133. """
  134. realimport = __builtin__.__import__
  135. def myimp(name, *args, **kwargs):
  136. if name in modnames:
  137. raise ImportError("No module named %s" % name)
  138. else:
  139. return realimport(name, *args, **kwargs)
  140. __builtin__.__import__ = myimp
  141. yield True
  142. __builtin__.__import__ = realimport
  143. @contextmanager
  144. def override_stdouts():
  145. """Override ``sys.stdout`` and ``sys.stderr`` with ``StringIO``."""
  146. prev_out, prev_err = sys.stdout, sys.stderr
  147. mystdout, mystderr = StringIO(), StringIO()
  148. sys.stdout = sys.__stdout__ = mystdout
  149. sys.stderr = sys.__stderr__ = mystderr
  150. yield mystdout, mystderr
  151. sys.stdout = sys.__stdout__ = prev_out
  152. sys.stderr = sys.__stderr__ = prev_err