utils.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from __future__ import generators
  2. import os
  3. import sys
  4. import __builtin__
  5. from StringIO import StringIO
  6. from nose import SkipTest
  7. from celery.app import default_app
  8. from celery.utils.functional import wraps
class GeneratorContextManager(object):
    """Context manager that drives a generator with a single ``yield``.

    Fallback used by ``fallback_contextmanager`` when ``contextlib`` cannot
    be imported. ``__enter__`` runs the generator up to its ``yield``;
    ``__exit__`` either resumes it (normal exit) or throws the pending
    exception into it at the ``yield`` point.
    """

    def __init__(self, gen):
        # gen: the generator returned by the decorated function. It must
        # support the Python 2 ``next()`` method; ``throw()`` is optional
        # (see the AttributeError branch in __exit__).
        self.gen = gen

    def __enter__(self):
        """Advance the generator to its ``yield`` and return the yielded value.

        :raises RuntimeError: if the generator finishes without yielding.
        """
        try:
            return self.gen.next()
        except StopIteration:
            raise RuntimeError("generator didn't yield")

    def __exit__(self, type, value, traceback):
        """Finish the generator once the ``with`` body is done.

        Returns ``True`` when the exception should be suppressed. Otherwise
        it returns ``None``, so the caller's original exception propagates.

        :raises RuntimeError: if the generator yields a second time.
        """
        if type is None:
            # Normal exit: resuming the generator must make it finish.
            try:
                self.gen.next()
            except StopIteration:
                return
            else:
                raise RuntimeError("generator didn't stop")
        else:
            try:
                # Re-raise the body's exception inside the generator.
                self.gen.throw(type, value, traceback)
                # Reaching here means the generator caught it and yielded
                # again; this RuntimeError is re-raised by the bare except.
                raise RuntimeError("generator didn't stop after throw()")
            except StopIteration:
                # The generator handled the exception and returned:
                # tell the caller to suppress it.
                return True
            except AttributeError:
                # Generator objects without ``throw()`` cannot receive the
                # exception, so re-raise it directly.
                # NOTE(review): this also catches an AttributeError raised
                # by the generator itself in response to throw().
                raise value
            except:
                # Propagate any exception that is not the one passed in.
                # If the generator re-raised ``value`` itself, fall through
                # and return None so the caller re-raises the original.
                if sys.exc_info()[1] is not value:
                    raise
  36. def fallback_contextmanager(fun):
  37. def helper(*args, **kwds):
  38. return GeneratorContextManager(fun(*args, **kwds))
  39. return helper
  40. def execute_context(context, fun):
  41. val = context.__enter__()
  42. exc_info = (None, None, None)
  43. retval = None
  44. try:
  45. retval = fun(val)
  46. except:
  47. exc_info = sys.exc_info()
  48. context.__exit__(*exc_info)
  49. return retval
  50. try:
  51. from contextlib import contextmanager
  52. except ImportError:
  53. contextmanager = fallback_contextmanager
  54. from celery.utils import noop
  55. @contextmanager
  56. def eager_tasks():
  57. prev = default_app.conf.CELERY_ALWAYS_EAGER
  58. default_app.conf.CELERY_ALWAYS_EAGER = True
  59. yield True
  60. default_app.conf.CELERY_ALWAYS_EAGER = prev
  61. def with_eager_tasks(fun):
  62. @wraps(fun)
  63. def _inner(*args, **kwargs):
  64. prev = default_app.conf.CELERY_ALWAYS_EAGER
  65. default_app.conf.CELERY_ALWAYS_EAGER = True
  66. try:
  67. return fun(*args, **kwargs)
  68. finally:
  69. default_app.conf.CELERY_ALWAYS_EAGER = prev
  70. def with_environ(env_name, env_value):
  71. def _envpatched(fun):
  72. @wraps(fun)
  73. def _patch_environ(*args, **kwargs):
  74. prev_val = os.environ.get(env_name)
  75. os.environ[env_name] = env_value
  76. try:
  77. return fun(*args, **kwargs)
  78. finally:
  79. if prev_val is not None:
  80. os.environ[env_name] = prev_val
  81. return _patch_environ
  82. return _envpatched
  83. def sleepdeprived(fun):
  84. @wraps(fun)
  85. def _sleepdeprived(*args, **kwargs):
  86. import time
  87. old_sleep = time.sleep
  88. time.sleep = noop
  89. try:
  90. return fun(*args, **kwargs)
  91. finally:
  92. time.sleep = old_sleep
  93. return _sleepdeprived
  94. def skip_if_environ(env_var_name):
  95. def _wrap_test(fun):
  96. @wraps(fun)
  97. def _skips_if_environ(*args, **kwargs):
  98. if os.environ.get(env_var_name):
  99. raise SkipTest("SKIP %s: %s set\n" % (
  100. fun.__name__, env_var_name))
  101. return fun(*args, **kwargs)
  102. return _skips_if_environ
  103. return _wrap_test
  104. def skip_if_quick(fun):
  105. return skip_if_environ("QUICKTEST")(fun)
  106. def _skip_test(reason, sign):
  107. def _wrap_test(fun):
  108. @wraps(fun)
  109. def _skipped_test(*args, **kwargs):
  110. raise SkipTest("%s: %s" % (sign, reason))
  111. return _skipped_test
  112. return _wrap_test
  113. def todo(reason):
  114. """TODO test decorator."""
  115. return _skip_test(reason, "TODO")
  116. def skip(reason):
  117. """Skip test decorator."""
  118. return _skip_test(reason, "SKIP")
  119. def skip_if(predicate, reason):
  120. """Skip test if predicate is ``True``."""
  121. def _inner(fun):
  122. return predicate and skip(reason)(fun) or fun
  123. return _inner
  124. def skip_unless(predicate, reason):
  125. """Skip test if predicate is ``False``."""
  126. return skip_if(not predicate, reason)
  127. # Taken from
  128. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  129. @contextmanager
  130. def mask_modules(*modnames):
  131. """Ban some modules from being importable inside the context
  132. For example:
  133. >>> with missing_modules("sys"):
  134. ... try:
  135. ... import sys
  136. ... except ImportError:
  137. ... print "sys not found"
  138. sys not found
  139. >>> import sys
  140. >>> sys.version
  141. (2, 5, 2, 'final', 0)
  142. """
  143. realimport = __builtin__.__import__
  144. def myimp(name, *args, **kwargs):
  145. if name in modnames:
  146. raise ImportError("No module named %s" % name)
  147. else:
  148. return realimport(name, *args, **kwargs)
  149. __builtin__.__import__ = myimp
  150. yield True
  151. __builtin__.__import__ = realimport
  152. @contextmanager
  153. def override_stdouts():
  154. """Override ``sys.stdout`` and ``sys.stderr`` with ``StringIO``."""
  155. prev_out, prev_err = sys.stdout, sys.stderr
  156. mystdout, mystderr = StringIO(), StringIO()
  157. sys.stdout = sys.__stdout__ = mystdout
  158. sys.stderr = sys.__stderr__ = mystderr
  159. yield mystdout, mystderr
  160. sys.stdout = sys.__stdout__ = prev_out
  161. sys.stderr = sys.__stderr__ = prev_err