utils.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. from __future__ import absolute_import
  2. try:
  3. import unittest
  4. unittest.skip
  5. except AttributeError:
  6. import unittest2 as unittest
  7. import importlib
  8. import logging
  9. import os
  10. import sys
  11. import time
  12. try:
  13. import __builtin__ as builtins
  14. except ImportError: # py3k
  15. import builtins # noqa
  16. from functools import wraps
  17. from contextlib import contextmanager
  18. import mock
  19. from nose import SkipTest
  20. from celery.app import app_or_default
  21. from celery.utils import noop
  22. from celery.utils.compat import StringIO, LoggerAdapter
  23. class Mock(mock.Mock):
  24. def __init__(self, *args, **kwargs):
  25. attrs = kwargs.pop("attrs", None) or {}
  26. super(Mock, self).__init__(*args, **kwargs)
  27. for attr_name, attr_value in attrs.items():
  28. setattr(self, attr_name, attr_value)
  29. def skip_unless_module(module):
  30. def _inner(fun):
  31. @wraps(fun)
  32. def __inner(*args, **kwargs):
  33. try:
  34. importlib.import_module(module)
  35. except ImportError:
  36. raise SkipTest("Does not have %s" % (module, ))
  37. return fun(*args, **kwargs)
  38. return __inner
  39. return _inner
  40. class AppCase(unittest.TestCase):
  41. def setUp(self):
  42. from celery.app import current_app
  43. self.app = self._current_app = current_app()
  44. self.setup()
  45. def tearDown(self):
  46. self.teardown()
  47. self._current_app.set_current()
  48. def setup(self):
  49. pass
  50. def teardown(self):
  51. pass
  52. def get_handlers(logger):
  53. if isinstance(logger, LoggerAdapter):
  54. return logger.logger.handlers
  55. return logger.handlers
  56. def set_handlers(logger, new_handlers):
  57. if isinstance(logger, LoggerAdapter):
  58. logger.logger.handlers = new_handlers
  59. logger.handlers = new_handlers
  60. @contextmanager
  61. def wrap_logger(logger, loglevel=logging.ERROR):
  62. old_handlers = get_handlers(logger)
  63. sio = StringIO()
  64. siohandler = logging.StreamHandler(sio)
  65. set_handlers(logger, [siohandler])
  66. yield sio
  67. set_handlers(logger, old_handlers)
  68. @contextmanager
  69. def eager_tasks():
  70. app = app_or_default()
  71. prev = app.conf.CELERY_ALWAYS_EAGER
  72. app.conf.CELERY_ALWAYS_EAGER = True
  73. yield True
  74. app.conf.CELERY_ALWAYS_EAGER = prev
  75. def with_eager_tasks(fun):
  76. @wraps(fun)
  77. def _inner(*args, **kwargs):
  78. app = app_or_default()
  79. prev = app.conf.CELERY_ALWAYS_EAGER
  80. app.conf.CELERY_ALWAYS_EAGER = True
  81. try:
  82. return fun(*args, **kwargs)
  83. finally:
  84. app.conf.CELERY_ALWAYS_EAGER = prev
  85. def with_environ(env_name, env_value):
  86. def _envpatched(fun):
  87. @wraps(fun)
  88. def _patch_environ(*args, **kwargs):
  89. prev_val = os.environ.get(env_name)
  90. os.environ[env_name] = env_value
  91. try:
  92. return fun(*args, **kwargs)
  93. finally:
  94. if prev_val is not None:
  95. os.environ[env_name] = prev_val
  96. return _patch_environ
  97. return _envpatched
  98. def sleepdeprived(module=time):
  99. def _sleepdeprived(fun):
  100. @wraps(fun)
  101. def __sleepdeprived(*args, **kwargs):
  102. old_sleep = module.sleep
  103. module.sleep = noop
  104. try:
  105. return fun(*args, **kwargs)
  106. finally:
  107. module.sleep = old_sleep
  108. return __sleepdeprived
  109. return _sleepdeprived
  110. def skip_if_environ(env_var_name):
  111. def _wrap_test(fun):
  112. @wraps(fun)
  113. def _skips_if_environ(*args, **kwargs):
  114. if os.environ.get(env_var_name):
  115. raise SkipTest("SKIP %s: %s set\n" % (
  116. fun.__name__, env_var_name))
  117. return fun(*args, **kwargs)
  118. return _skips_if_environ
  119. return _wrap_test
  120. def skip_if_quick(fun):
  121. return skip_if_environ("QUICKTEST")(fun)
  122. def _skip_test(reason, sign):
  123. def _wrap_test(fun):
  124. @wraps(fun)
  125. def _skipped_test(*args, **kwargs):
  126. raise SkipTest("%s: %s" % (sign, reason))
  127. return _skipped_test
  128. return _wrap_test
  129. def todo(reason):
  130. """TODO test decorator."""
  131. return _skip_test(reason, "TODO")
  132. def skip(reason):
  133. """Skip test decorator."""
  134. return _skip_test(reason, "SKIP")
  135. def skip_if(predicate, reason):
  136. """Skip test if predicate is :const:`True`."""
  137. def _inner(fun):
  138. return predicate and skip(reason)(fun) or fun
  139. return _inner
  140. def skip_unless(predicate, reason):
  141. """Skip test if predicate is :const:`False`."""
  142. return skip_if(not predicate, reason)
  143. # Taken from
  144. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  145. @contextmanager
  146. def mask_modules(*modnames):
  147. """Ban some modules from being importable inside the context
  148. For example:
  149. >>> with missing_modules("sys"):
  150. ... try:
  151. ... import sys
  152. ... except ImportError:
  153. ... print "sys not found"
  154. sys not found
  155. >>> import sys
  156. >>> sys.version
  157. (2, 5, 2, 'final', 0)
  158. """
  159. realimport = builtins.__import__
  160. def myimp(name, *args, **kwargs):
  161. if name in modnames:
  162. raise ImportError("No module named %s" % name)
  163. else:
  164. return realimport(name, *args, **kwargs)
  165. builtins.__import__ = myimp
  166. yield True
  167. builtins.__import__ = realimport
  168. @contextmanager
  169. def override_stdouts():
  170. """Override `sys.stdout` and `sys.stderr` with `StringIO`."""
  171. prev_out, prev_err = sys.stdout, sys.stderr
  172. mystdout, mystderr = StringIO(), StringIO()
  173. sys.stdout = sys.__stdout__ = mystdout
  174. sys.stderr = sys.__stderr__ = mystderr
  175. yield mystdout, mystderr
  176. sys.stdout = sys.__stdout__ = prev_out
  177. sys.stderr = sys.__stderr__ = prev_err
  178. def patch(module, name, mocked):
  179. module = importlib.import_module(module)
  180. def _patch(fun):
  181. @wraps(fun)
  182. def __patched(*args, **kwargs):
  183. prev = getattr(module, name)
  184. setattr(module, name, mocked)
  185. try:
  186. return fun(*args, **kwargs)
  187. finally:
  188. setattr(module, name, prev)
  189. return __patched
  190. return _patch
  191. @contextmanager
  192. def platform_pyimp(replace=None):
  193. import platform
  194. prev = getattr(platform, "python_implementation", None)
  195. if replace:
  196. platform.python_implementation = replace
  197. else:
  198. try:
  199. delattr(platform, "python_implementation")
  200. except AttributeError:
  201. pass
  202. yield
  203. if prev is not None:
  204. platform.python_implementation = prev
  205. @contextmanager
  206. def sys_platform(value):
  207. prev, sys.platform = sys.platform, value
  208. yield
  209. sys.platform = prev
  210. @contextmanager
  211. def pypy_version(value=None):
  212. prev = getattr(sys, "pypy_version_info", None)
  213. if value:
  214. sys.pypy_version_info = value
  215. else:
  216. try:
  217. delattr(sys, "pypy_version_info")
  218. except AttributeError:
  219. pass
  220. yield
  221. if prev is not None:
  222. sys.pypy_version_info = prev
  223. @contextmanager
  224. def reset_modules(*modules):
  225. prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
  226. yield
  227. sys.modules.update(prev)
  228. @contextmanager
  229. def patch_modules(*modules):
  230. from types import ModuleType
  231. prev = {}
  232. for mod in modules:
  233. prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
  234. yield
  235. for name, mod in prev.iteritems():
  236. if mod is None:
  237. sys.modules.pop(name, None)
  238. else:
  239. sys.modules[name] = mod