utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. from __future__ import generators
  2. try:
  3. import unittest
  4. unittest.skip
  5. except AttributeError:
  6. import unittest2 as unittest
  7. import importlib
  8. import logging
  9. import os
  10. import sys
  11. import time
  12. try:
  13. import __builtin__ as builtins
  14. except ImportError: # py3k
  15. import builtins
  16. from celery.utils.compat import StringIO, LoggerAdapter
  17. try:
  18. from contextlib import contextmanager
  19. except ImportError:
  20. from celery.tests.utils import fallback_contextmanager as contextmanager
  21. from nose import SkipTest
  22. from celery.app import app_or_default
  23. from celery.utils.functional import wraps
  24. class AppCase(unittest.TestCase):
  25. def setUp(self):
  26. from celery.app import current_app
  27. self._current_app = current_app()
  28. self.setup()
  29. def tearDown(self):
  30. self.teardown()
  31. self._current_app.set_current()
  32. def setup(self):
  33. pass
  34. def teardown(self):
  35. pass
  36. class GeneratorContextManager(object):
  37. def __init__(self, gen):
  38. self.gen = gen
  39. def __enter__(self):
  40. try:
  41. return self.gen.next()
  42. except StopIteration:
  43. raise RuntimeError("generator didn't yield")
  44. def __exit__(self, type, value, traceback):
  45. if type is None:
  46. try:
  47. self.gen.next()
  48. except StopIteration:
  49. return
  50. else:
  51. raise RuntimeError("generator didn't stop")
  52. else:
  53. try:
  54. self.gen.throw(type, value, traceback)
  55. raise RuntimeError("generator didn't stop after throw()")
  56. except StopIteration:
  57. return True
  58. except AttributeError:
  59. raise value
  60. except:
  61. if sys.exc_info()[1] is not value:
  62. raise
  63. def get_handlers(logger):
  64. if isinstance(logger, LoggerAdapter):
  65. return logger.logger.handlers
  66. return logger.handlers
  67. def set_handlers(logger, new_handlers):
  68. if isinstance(logger, LoggerAdapter):
  69. logger.logger.handlers = new_handlers
  70. logger.handlers = new_handlers
  71. @contextmanager
  72. def wrap_logger(logger, loglevel=logging.ERROR):
  73. old_handlers = get_handlers(logger)
  74. sio = StringIO()
  75. siohandler = logging.StreamHandler(sio)
  76. set_handlers(logger, [siohandler])
  77. yield sio
  78. set_handlers(logger, old_handlers)
  79. def fallback_contextmanager(fun):
  80. def helper(*args, **kwds):
  81. return GeneratorContextManager(fun(*args, **kwds))
  82. return helper
def execute_context(context, fun):
    """Call ``fun(context.__enter__())`` and always call ``context.__exit__``.

    This emulates a ``with`` statement. ``__exit__`` receives the
    ``sys.exc_info()`` triple if *fun* raised, and ``(None, None, None)``
    otherwise. Returns whatever *fun* returns. If ``__enter__`` itself
    raises, ``__exit__`` is not called.

    NOTE(review): unlike a real ``with`` statement, a true return value
    from ``__exit__`` does not suppress the exception. The ``raise`` below
    has already started propagation, and the return value of ``__exit__``
    is ignored. Confirm no caller relies on suppression.
    """
    val = context.__enter__()
    exc_info = (None, None, None)
    # A try/except nested in try/finally (rather than one combined
    # try/except/finally), presumably so the code also runs on pre-2.5
    # interpreters.
    try:
        try:
            return fun(val)
        except:
            # Record the exception for __exit__, then let it propagate.
            exc_info = sys.exc_info()
            raise
    finally:
        context.__exit__(*exc_info)
  94. try:
  95. from contextlib import contextmanager
  96. except ImportError:
  97. contextmanager = fallback_contextmanager
  98. from celery.utils import noop
  99. @contextmanager
  100. def eager_tasks():
  101. app = app_or_default()
  102. prev = app.conf.CELERY_ALWAYS_EAGER
  103. app.conf.CELERY_ALWAYS_EAGER = True
  104. yield True
  105. app.conf.CELERY_ALWAYS_EAGER = prev
  106. def with_eager_tasks(fun):
  107. @wraps(fun)
  108. def _inner(*args, **kwargs):
  109. app = app_or_default()
  110. prev = app.conf.CELERY_ALWAYS_EAGER
  111. app.conf.CELERY_ALWAYS_EAGER = True
  112. try:
  113. return fun(*args, **kwargs)
  114. finally:
  115. app.conf.CELERY_ALWAYS_EAGER = prev
  116. def with_environ(env_name, env_value):
  117. def _envpatched(fun):
  118. @wraps(fun)
  119. def _patch_environ(*args, **kwargs):
  120. prev_val = os.environ.get(env_name)
  121. os.environ[env_name] = env_value
  122. try:
  123. return fun(*args, **kwargs)
  124. finally:
  125. if prev_val is not None:
  126. os.environ[env_name] = prev_val
  127. return _patch_environ
  128. return _envpatched
  129. def sleepdeprived(module=time):
  130. def _sleepdeprived(fun):
  131. @wraps(fun)
  132. def __sleepdeprived(*args, **kwargs):
  133. old_sleep = module.sleep
  134. module.sleep = noop
  135. try:
  136. return fun(*args, **kwargs)
  137. finally:
  138. module.sleep = old_sleep
  139. return __sleepdeprived
  140. return _sleepdeprived
  141. def skip_if_environ(env_var_name):
  142. def _wrap_test(fun):
  143. @wraps(fun)
  144. def _skips_if_environ(*args, **kwargs):
  145. if os.environ.get(env_var_name):
  146. raise SkipTest("SKIP %s: %s set\n" % (
  147. fun.__name__, env_var_name))
  148. return fun(*args, **kwargs)
  149. return _skips_if_environ
  150. return _wrap_test
  151. def skip_if_quick(fun):
  152. return skip_if_environ("QUICKTEST")(fun)
  153. def _skip_test(reason, sign):
  154. def _wrap_test(fun):
  155. @wraps(fun)
  156. def _skipped_test(*args, **kwargs):
  157. raise SkipTest("%s: %s" % (sign, reason))
  158. return _skipped_test
  159. return _wrap_test
  160. def todo(reason):
  161. """TODO test decorator."""
  162. return _skip_test(reason, "TODO")
  163. def skip(reason):
  164. """Skip test decorator."""
  165. return _skip_test(reason, "SKIP")
  166. def skip_if(predicate, reason):
  167. """Skip test if predicate is :const:`True`."""
  168. def _inner(fun):
  169. return predicate and skip(reason)(fun) or fun
  170. return _inner
  171. def skip_unless(predicate, reason):
  172. """Skip test if predicate is :const:`False`."""
  173. return skip_if(not predicate, reason)
  174. # Taken from
  175. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  176. @contextmanager
  177. def mask_modules(*modnames):
  178. """Ban some modules from being importable inside the context
  179. For example:
  180. >>> with missing_modules("sys"):
  181. ... try:
  182. ... import sys
  183. ... except ImportError:
  184. ... print "sys not found"
  185. sys not found
  186. >>> import sys
  187. >>> sys.version
  188. (2, 5, 2, 'final', 0)
  189. """
  190. realimport = builtins.__import__
  191. def myimp(name, *args, **kwargs):
  192. if name in modnames:
  193. raise ImportError("No module named %s" % name)
  194. else:
  195. return realimport(name, *args, **kwargs)
  196. builtins.__import__ = myimp
  197. yield True
  198. builtins.__import__ = realimport
  199. @contextmanager
  200. def override_stdouts():
  201. """Override `sys.stdout` and `sys.stderr` with `StringIO`."""
  202. prev_out, prev_err = sys.stdout, sys.stderr
  203. mystdout, mystderr = StringIO(), StringIO()
  204. sys.stdout = sys.__stdout__ = mystdout
  205. sys.stderr = sys.__stderr__ = mystderr
  206. yield mystdout, mystderr
  207. sys.stdout = sys.__stdout__ = prev_out
  208. sys.stderr = sys.__stderr__ = prev_err
  209. def patch(module, name, mocked):
  210. module = importlib.import_module(module)
  211. def _patch(fun):
  212. @wraps(fun)
  213. def __patched(*args, **kwargs):
  214. prev = getattr(module, name)
  215. setattr(module, name, mocked)
  216. try:
  217. return fun(*args, **kwargs)
  218. finally:
  219. setattr(module, name, prev)
  220. return __patched
  221. return _patch