utils.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. from __future__ import absolute_import
  2. try:
  3. import unittest
  4. unittest.skip
  5. except AttributeError:
  6. import unittest2 as unittest
  7. import importlib
  8. import logging
  9. import os
  10. import sys
  11. import time
  12. try:
  13. import __builtin__ as builtins
  14. except ImportError: # py3k
  15. import builtins # noqa
  16. from functools import wraps
  17. from contextlib import contextmanager
  18. import mock
  19. from nose import SkipTest
  20. from celery.app import app_or_default
  21. from celery.utils import noop
  22. from celery.utils.compat import StringIO, LoggerAdapter
  23. class Mock(mock.Mock):
  24. def __init__(self, *args, **kwargs):
  25. attrs = kwargs.pop("attrs", None) or {}
  26. super(Mock, self).__init__(*args, **kwargs)
  27. for attr_name, attr_value in attrs.items():
  28. setattr(self, attr_name, attr_value)
  29. class AppCase(unittest.TestCase):
  30. def setUp(self):
  31. from celery.app import current_app
  32. self.app = self._current_app = current_app()
  33. self.setup()
  34. def tearDown(self):
  35. self.teardown()
  36. self._current_app.set_current()
  37. def setup(self):
  38. pass
  39. def teardown(self):
  40. pass
  41. def get_handlers(logger):
  42. if isinstance(logger, LoggerAdapter):
  43. return logger.logger.handlers
  44. return logger.handlers
  45. def set_handlers(logger, new_handlers):
  46. if isinstance(logger, LoggerAdapter):
  47. logger.logger.handlers = new_handlers
  48. logger.handlers = new_handlers
  49. @contextmanager
  50. def wrap_logger(logger, loglevel=logging.ERROR):
  51. old_handlers = get_handlers(logger)
  52. sio = StringIO()
  53. siohandler = logging.StreamHandler(sio)
  54. set_handlers(logger, [siohandler])
  55. yield sio
  56. set_handlers(logger, old_handlers)
  57. @contextmanager
  58. def eager_tasks():
  59. app = app_or_default()
  60. prev = app.conf.CELERY_ALWAYS_EAGER
  61. app.conf.CELERY_ALWAYS_EAGER = True
  62. yield True
  63. app.conf.CELERY_ALWAYS_EAGER = prev
  64. def with_eager_tasks(fun):
  65. @wraps(fun)
  66. def _inner(*args, **kwargs):
  67. app = app_or_default()
  68. prev = app.conf.CELERY_ALWAYS_EAGER
  69. app.conf.CELERY_ALWAYS_EAGER = True
  70. try:
  71. return fun(*args, **kwargs)
  72. finally:
  73. app.conf.CELERY_ALWAYS_EAGER = prev
  74. def with_environ(env_name, env_value):
  75. def _envpatched(fun):
  76. @wraps(fun)
  77. def _patch_environ(*args, **kwargs):
  78. prev_val = os.environ.get(env_name)
  79. os.environ[env_name] = env_value
  80. try:
  81. return fun(*args, **kwargs)
  82. finally:
  83. if prev_val is not None:
  84. os.environ[env_name] = prev_val
  85. return _patch_environ
  86. return _envpatched
  87. def sleepdeprived(module=time):
  88. def _sleepdeprived(fun):
  89. @wraps(fun)
  90. def __sleepdeprived(*args, **kwargs):
  91. old_sleep = module.sleep
  92. module.sleep = noop
  93. try:
  94. return fun(*args, **kwargs)
  95. finally:
  96. module.sleep = old_sleep
  97. return __sleepdeprived
  98. return _sleepdeprived
  99. def skip_if_environ(env_var_name):
  100. def _wrap_test(fun):
  101. @wraps(fun)
  102. def _skips_if_environ(*args, **kwargs):
  103. if os.environ.get(env_var_name):
  104. raise SkipTest("SKIP %s: %s set\n" % (
  105. fun.__name__, env_var_name))
  106. return fun(*args, **kwargs)
  107. return _skips_if_environ
  108. return _wrap_test
  109. def skip_if_quick(fun):
  110. return skip_if_environ("QUICKTEST")(fun)
  111. def _skip_test(reason, sign):
  112. def _wrap_test(fun):
  113. @wraps(fun)
  114. def _skipped_test(*args, **kwargs):
  115. raise SkipTest("%s: %s" % (sign, reason))
  116. return _skipped_test
  117. return _wrap_test
  118. def todo(reason):
  119. """TODO test decorator."""
  120. return _skip_test(reason, "TODO")
  121. def skip(reason):
  122. """Skip test decorator."""
  123. return _skip_test(reason, "SKIP")
  124. def skip_if(predicate, reason):
  125. """Skip test if predicate is :const:`True`."""
  126. def _inner(fun):
  127. return predicate and skip(reason)(fun) or fun
  128. return _inner
  129. def skip_unless(predicate, reason):
  130. """Skip test if predicate is :const:`False`."""
  131. return skip_if(not predicate, reason)
  132. # Taken from
  133. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  134. @contextmanager
  135. def mask_modules(*modnames):
  136. """Ban some modules from being importable inside the context
  137. For example:
  138. >>> with missing_modules("sys"):
  139. ... try:
  140. ... import sys
  141. ... except ImportError:
  142. ... print "sys not found"
  143. sys not found
  144. >>> import sys
  145. >>> sys.version
  146. (2, 5, 2, 'final', 0)
  147. """
  148. realimport = builtins.__import__
  149. def myimp(name, *args, **kwargs):
  150. if name in modnames:
  151. raise ImportError("No module named %s" % name)
  152. else:
  153. return realimport(name, *args, **kwargs)
  154. builtins.__import__ = myimp
  155. yield True
  156. builtins.__import__ = realimport
  157. @contextmanager
  158. def override_stdouts():
  159. """Override `sys.stdout` and `sys.stderr` with `StringIO`."""
  160. prev_out, prev_err = sys.stdout, sys.stderr
  161. mystdout, mystderr = StringIO(), StringIO()
  162. sys.stdout = sys.__stdout__ = mystdout
  163. sys.stderr = sys.__stderr__ = mystderr
  164. yield mystdout, mystderr
  165. sys.stdout = sys.__stdout__ = prev_out
  166. sys.stderr = sys.__stderr__ = prev_err
  167. def patch(module, name, mocked):
  168. module = importlib.import_module(module)
  169. def _patch(fun):
  170. @wraps(fun)
  171. def __patched(*args, **kwargs):
  172. prev = getattr(module, name)
  173. setattr(module, name, mocked)
  174. try:
  175. return fun(*args, **kwargs)
  176. finally:
  177. setattr(module, name, prev)
  178. return __patched
  179. return _patch
  180. @contextmanager
  181. def platform_pyimp(replace=None):
  182. import platform
  183. prev = getattr(platform, "python_implementation", None)
  184. if replace:
  185. platform.python_implementation = replace
  186. else:
  187. try:
  188. delattr(platform, "python_implementation")
  189. except AttributeError:
  190. pass
  191. yield
  192. if prev is not None:
  193. platform.python_implementation = prev
  194. @contextmanager
  195. def sys_platform(value):
  196. prev, sys.platform = sys.platform, value
  197. yield
  198. sys.platform = prev
  199. @contextmanager
  200. def pypy_version(value=None):
  201. prev = getattr(sys, "pypy_version_info", None)
  202. if value:
  203. sys.pypy_version_info = value
  204. else:
  205. try:
  206. delattr(sys, "pypy_version_info")
  207. except AttributeError:
  208. pass
  209. yield
  210. if prev is not None:
  211. sys.pypy_version_info = prev
  212. @contextmanager
  213. def reset_modules(*modules):
  214. prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
  215. yield
  216. sys.modules.update(prev)
  217. @contextmanager
  218. def patch_modules(*modules):
  219. from types import ModuleType
  220. prev = {}
  221. for mod in modules:
  222. prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
  223. yield
  224. for name, mod in prev.iteritems():
  225. if mod is None:
  226. sys.modules.pop(name, None)
  227. else:
  228. sys.modules[name] = mod