utils.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. from __future__ import generators
  2. try:
  3. import unittest
  4. unittest.skip
  5. except AttributeError:
  6. import unittest2 as unittest
  7. import os
  8. import sys
  9. import time
  10. try:
  11. import __builtin__ as builtins
  12. except ImportError: # py3k
  13. import builtins
  14. from celery.utils.compat import StringIO
  15. from nose import SkipTest
  16. from celery.app import app_or_default
  17. from celery.utils.functional import wraps
  18. class GeneratorContextManager(object):
  19. def __init__(self, gen):
  20. self.gen = gen
  21. def __enter__(self):
  22. try:
  23. return self.gen.next()
  24. except StopIteration:
  25. raise RuntimeError("generator didn't yield")
  26. def __exit__(self, type, value, traceback):
  27. if type is None:
  28. try:
  29. self.gen.next()
  30. except StopIteration:
  31. return
  32. else:
  33. raise RuntimeError("generator didn't stop")
  34. else:
  35. try:
  36. self.gen.throw(type, value, traceback)
  37. raise RuntimeError("generator didn't stop after throw()")
  38. except StopIteration:
  39. return True
  40. except AttributeError:
  41. raise value
  42. except:
  43. if sys.exc_info()[1] is not value:
  44. raise
  45. def fallback_contextmanager(fun):
  46. def helper(*args, **kwds):
  47. return GeneratorContextManager(fun(*args, **kwds))
  48. return helper
  49. def execute_context(context, fun):
  50. val = context.__enter__()
  51. exc_info = (None, None, None)
  52. retval = None
  53. try:
  54. try:
  55. return fun(val)
  56. except:
  57. exc_info = sys.exc_info()
  58. raise
  59. finally:
  60. context.__exit__(*exc_info)
  61. try:
  62. from contextlib import contextmanager
  63. except ImportError:
  64. contextmanager = fallback_contextmanager
  65. from celery.utils import noop
  66. @contextmanager
  67. def eager_tasks():
  68. app = app_or_default()
  69. prev = app.conf.CELERY_ALWAYS_EAGER
  70. app.conf.CELERY_ALWAYS_EAGER = True
  71. yield True
  72. app.conf.CELERY_ALWAYS_EAGER = prev
  73. def with_eager_tasks(fun):
  74. @wraps(fun)
  75. def _inner(*args, **kwargs):
  76. app = app_or_default()
  77. prev = app.conf.CELERY_ALWAYS_EAGER
  78. app.conf.CELERY_ALWAYS_EAGER = True
  79. try:
  80. return fun(*args, **kwargs)
  81. finally:
  82. app.conf.CELERY_ALWAYS_EAGER = prev
  83. def with_environ(env_name, env_value):
  84. def _envpatched(fun):
  85. @wraps(fun)
  86. def _patch_environ(*args, **kwargs):
  87. prev_val = os.environ.get(env_name)
  88. os.environ[env_name] = env_value
  89. try:
  90. return fun(*args, **kwargs)
  91. finally:
  92. if prev_val is not None:
  93. os.environ[env_name] = prev_val
  94. return _patch_environ
  95. return _envpatched
  96. def sleepdeprived(module=time):
  97. def _sleepdeprived(fun):
  98. @wraps(fun)
  99. def __sleepdeprived(*args, **kwargs):
  100. old_sleep = module.sleep
  101. module.sleep = noop
  102. try:
  103. return fun(*args, **kwargs)
  104. finally:
  105. module.sleep = old_sleep
  106. return __sleepdeprived
  107. return _sleepdeprived
  108. def skip_if_environ(env_var_name):
  109. def _wrap_test(fun):
  110. @wraps(fun)
  111. def _skips_if_environ(*args, **kwargs):
  112. if os.environ.get(env_var_name):
  113. raise SkipTest("SKIP %s: %s set\n" % (
  114. fun.__name__, env_var_name))
  115. return fun(*args, **kwargs)
  116. return _skips_if_environ
  117. return _wrap_test
  118. def skip_if_quick(fun):
  119. return skip_if_environ("QUICKTEST")(fun)
  120. def _skip_test(reason, sign):
  121. def _wrap_test(fun):
  122. @wraps(fun)
  123. def _skipped_test(*args, **kwargs):
  124. raise SkipTest("%s: %s" % (sign, reason))
  125. return _skipped_test
  126. return _wrap_test
  127. def todo(reason):
  128. """TODO test decorator."""
  129. return _skip_test(reason, "TODO")
  130. def skip(reason):
  131. """Skip test decorator."""
  132. return _skip_test(reason, "SKIP")
  133. def skip_if(predicate, reason):
  134. """Skip test if predicate is :const:`True`."""
  135. def _inner(fun):
  136. return predicate and skip(reason)(fun) or fun
  137. return _inner
  138. def skip_unless(predicate, reason):
  139. """Skip test if predicate is :const:`False`."""
  140. return skip_if(not predicate, reason)
  141. # Taken from
  142. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  143. @contextmanager
  144. def mask_modules(*modnames):
  145. """Ban some modules from being importable inside the context
  146. For example:
  147. >>> with missing_modules("sys"):
  148. ... try:
  149. ... import sys
  150. ... except ImportError:
  151. ... print "sys not found"
  152. sys not found
  153. >>> import sys
  154. >>> sys.version
  155. (2, 5, 2, 'final', 0)
  156. """
  157. realimport = builtins.__import__
  158. def myimp(name, *args, **kwargs):
  159. if name in modnames:
  160. raise ImportError("No module named %s" % name)
  161. else:
  162. return realimport(name, *args, **kwargs)
  163. builtins.__import__ = myimp
  164. yield True
  165. builtins.__import__ = realimport
  166. @contextmanager
  167. def override_stdouts():
  168. """Override `sys.stdout` and `sys.stderr` with `StringIO`."""
  169. prev_out, prev_err = sys.stdout, sys.stderr
  170. mystdout, mystderr = StringIO(), StringIO()
  171. sys.stdout = sys.__stdout__ = mystdout
  172. sys.stderr = sys.__stderr__ = mystderr
  173. yield mystdout, mystderr
  174. sys.stdout = sys.__stdout__ = prev_out
  175. sys.stderr = sys.__stderr__ = prev_err