utils.py 15 KB


  1. from __future__ import absolute_import
  2. try:
  3. import unittest
  4. unittest.skip
  5. from unittest.util import safe_repr, unorderable_list_difference
  6. except AttributeError:
  7. import unittest2 as unittest
  8. from unittest2.util import safe_repr, unorderable_list_difference # noqa
  9. import importlib
  10. import logging
  11. import os
  12. import platform
  13. import re
  14. import sys
  15. import time
  16. import warnings
  17. try:
  18. import __builtin__ as builtins
  19. except ImportError: # py3k
  20. import builtins # noqa
  21. from contextlib import contextmanager
  22. from functools import partial, wraps
  23. from types import ModuleType
  24. import mock
  25. from nose import SkipTest
  26. from kombu.log import NullHandler
  27. from kombu.utils import nested
  28. from ..app import app_or_default
  29. from ..utils.compat import WhateverIO
  30. from ..utils.functional import noop
  31. class Mock(mock.Mock):
  32. def __init__(self, *args, **kwargs):
  33. attrs = kwargs.pop('attrs', None) or {}
  34. super(Mock, self).__init__(*args, **kwargs)
  35. for attr_name, attr_value in attrs.items():
  36. setattr(self, attr_name, attr_value)
  37. def skip_unless_module(module):
  38. def _inner(fun):
  39. @wraps(fun)
  40. def __inner(*args, **kwargs):
  41. try:
  42. importlib.import_module(module)
  43. except ImportError:
  44. raise SkipTest('Does not have %s' % (module, ))
  45. return fun(*args, **kwargs)
  46. return __inner
  47. return _inner
  48. # -- adds assertWarns from recent unittest2, not in Python 2.7.
  49. class _AssertRaisesBaseContext(object):
  50. def __init__(self, expected, test_case, callable_obj=None,
  51. expected_regex=None):
  52. self.expected = expected
  53. self.failureException = test_case.failureException
  54. self.obj_name = None
  55. if isinstance(expected_regex, basestring):
  56. expected_regex = re.compile(expected_regex)
  57. self.expected_regex = expected_regex
  58. class _AssertWarnsContext(_AssertRaisesBaseContext):
  59. """A context manager used to implement TestCase.assertWarns* methods."""
  60. def __enter__(self):
  61. # The __warningregistry__'s need to be in a pristine state for tests
  62. # to work properly.
  63. warnings.resetwarnings()
  64. for v in sys.modules.values():
  65. if getattr(v, '__warningregistry__', None):
  66. v.__warningregistry__ = {}
  67. self.warnings_manager = warnings.catch_warnings(record=True)
  68. self.warnings = self.warnings_manager.__enter__()
  69. warnings.simplefilter('always', self.expected)
  70. return self
  71. def __exit__(self, exc_type, exc_value, tb):
  72. self.warnings_manager.__exit__(exc_type, exc_value, tb)
  73. if exc_type is not None:
  74. # let unexpected exceptions pass through
  75. return
  76. try:
  77. exc_name = self.expected.__name__
  78. except AttributeError:
  79. exc_name = str(self.expected)
  80. first_matching = None
  81. for m in self.warnings:
  82. w = m.message
  83. if not isinstance(w, self.expected):
  84. continue
  85. if first_matching is None:
  86. first_matching = w
  87. if (self.expected_regex is not None and
  88. not self.expected_regex.search(str(w))):
  89. continue
  90. # store warning for later retrieval
  91. self.warning = w
  92. self.filename = m.filename
  93. self.lineno = m.lineno
  94. return
  95. # Now we simply try to choose a helpful failure message
  96. if first_matching is not None:
  97. raise self.failureException('%r does not match %r' %
  98. (self.expected_regex.pattern, str(first_matching)))
  99. if self.obj_name:
  100. raise self.failureException('%s not triggered by %s'
  101. % (exc_name, self.obj_name))
  102. else:
  103. raise self.failureException('%s not triggered'
  104. % exc_name)
  105. class Case(unittest.TestCase):
  106. def assertWarns(self, expected_warning):
  107. return _AssertWarnsContext(expected_warning, self, None)
  108. def assertWarnsRegex(self, expected_warning, expected_regex):
  109. return _AssertWarnsContext(expected_warning, self,
  110. None, expected_regex)
  111. def assertDictContainsSubset(self, expected, actual, msg=None):
  112. missing, mismatched = [], []
  113. for key, value in expected.iteritems():
  114. if key not in actual:
  115. missing.append(key)
  116. elif value != actual[key]:
  117. mismatched.append('%s, expected: %s, actual: %s' % (
  118. safe_repr(key), safe_repr(value),
  119. safe_repr(actual[key])))
  120. if not (missing or mismatched):
  121. return
  122. standard_msg = ''
  123. if missing:
  124. standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
  125. if mismatched:
  126. if standard_msg:
  127. standard_msg += '; '
  128. standard_msg += 'Mismatched values: %s' % (
  129. ','.join(mismatched))
  130. self.fail(self._formatMessage(msg, standard_msg))
  131. def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  132. try:
  133. expected = sorted(expected_seq)
  134. actual = sorted(actual_seq)
  135. except TypeError:
  136. # Unsortable items (example: set(), complex(), ...)
  137. expected = list(expected_seq)
  138. actual = list(actual_seq)
  139. missing, unexpected = unorderable_list_difference(
  140. expected, actual)
  141. else:
  142. return self.assertSequenceEqual(expected, actual, msg=msg)
  143. errors = []
  144. if missing:
  145. errors.append('Expected, but missing:\n %s' % (
  146. safe_repr(missing)))
  147. if unexpected:
  148. errors.append('Unexpected, but present:\n %s' % (
  149. safe_repr(unexpected)))
  150. if errors:
  151. standardMsg = '\n'.join(errors)
  152. self.fail(self._formatMessage(msg, standardMsg))
  153. class AppCase(Case):
  154. def setUp(self):
  155. from ..app import current_app
  156. from ..backends.cache import CacheBackend, DummyClient
  157. app = self.app = self._current_app = current_app()
  158. if isinstance(app.backend, CacheBackend):
  159. if isinstance(app.backend.client, DummyClient):
  160. app.backend.client.cache.clear()
  161. app.backend._cache.clear()
  162. self.setup()
  163. def tearDown(self):
  164. self.teardown()
  165. self._current_app.set_current()
  166. def setup(self):
  167. pass
  168. def teardown(self):
  169. pass
  170. def get_handlers(logger):
  171. return [h for h in logger.handlers if not isinstance(h, NullHandler)]
  172. @contextmanager
  173. def wrap_logger(logger, loglevel=logging.ERROR):
  174. old_handlers = get_handlers(logger)
  175. sio = WhateverIO()
  176. siohandler = logging.StreamHandler(sio)
  177. logger.handlers = [siohandler]
  178. yield sio
  179. logger.handlers = old_handlers
  180. @contextmanager
  181. def eager_tasks():
  182. app = app_or_default()
  183. prev = app.conf.CELERY_ALWAYS_EAGER
  184. app.conf.CELERY_ALWAYS_EAGER = True
  185. yield True
  186. app.conf.CELERY_ALWAYS_EAGER = prev
  187. def with_eager_tasks(fun):
  188. @wraps(fun)
  189. def _inner(*args, **kwargs):
  190. app = app_or_default()
  191. prev = app.conf.CELERY_ALWAYS_EAGER
  192. app.conf.CELERY_ALWAYS_EAGER = True
  193. try:
  194. return fun(*args, **kwargs)
  195. finally:
  196. app.conf.CELERY_ALWAYS_EAGER = prev
  197. def with_environ(env_name, env_value):
  198. def _envpatched(fun):
  199. @wraps(fun)
  200. def _patch_environ(*args, **kwargs):
  201. prev_val = os.environ.get(env_name)
  202. os.environ[env_name] = env_value
  203. try:
  204. return fun(*args, **kwargs)
  205. finally:
  206. if prev_val is not None:
  207. os.environ[env_name] = prev_val
  208. return _patch_environ
  209. return _envpatched
  210. def sleepdeprived(module=time):
  211. def _sleepdeprived(fun):
  212. @wraps(fun)
  213. def __sleepdeprived(*args, **kwargs):
  214. old_sleep = module.sleep
  215. module.sleep = noop
  216. try:
  217. return fun(*args, **kwargs)
  218. finally:
  219. module.sleep = old_sleep
  220. return __sleepdeprived
  221. return _sleepdeprived
  222. def skip_if_environ(env_var_name):
  223. def _wrap_test(fun):
  224. @wraps(fun)
  225. def _skips_if_environ(*args, **kwargs):
  226. if os.environ.get(env_var_name):
  227. raise SkipTest('SKIP %s: %s set\n' % (
  228. fun.__name__, env_var_name))
  229. return fun(*args, **kwargs)
  230. return _skips_if_environ
  231. return _wrap_test
  232. def skip_if_quick(fun):
  233. return skip_if_environ('QUICKTEST')(fun)
  234. def _skip_test(reason, sign):
  235. def _wrap_test(fun):
  236. @wraps(fun)
  237. def _skipped_test(*args, **kwargs):
  238. raise SkipTest('%s: %s' % (sign, reason))
  239. return _skipped_test
  240. return _wrap_test
  241. def todo(reason):
  242. """TODO test decorator."""
  243. return _skip_test(reason, 'TODO')
  244. def skip(reason):
  245. """Skip test decorator."""
  246. return _skip_test(reason, 'SKIP')
  247. def skip_if(predicate, reason):
  248. """Skip test if predicate is :const:`True`."""
  249. def _inner(fun):
  250. return predicate and skip(reason)(fun) or fun
  251. return _inner
  252. def skip_unless(predicate, reason):
  253. """Skip test if predicate is :const:`False`."""
  254. return skip_if(not predicate, reason)
  255. # Taken from
  256. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  257. @contextmanager
  258. def mask_modules(*modnames):
  259. """Ban some modules from being importable inside the context
  260. For example:
  261. >>> with missing_modules('sys'):
  262. ... try:
  263. ... import sys
  264. ... except ImportError:
  265. ... print 'sys not found'
  266. sys not found
  267. >>> import sys
  268. >>> sys.version
  269. (2, 5, 2, 'final', 0)
  270. """
  271. realimport = builtins.__import__
  272. def myimp(name, *args, **kwargs):
  273. if name in modnames:
  274. raise ImportError('No module named %s' % name)
  275. else:
  276. return realimport(name, *args, **kwargs)
  277. builtins.__import__ = myimp
  278. yield True
  279. builtins.__import__ = realimport
  280. @contextmanager
  281. def override_stdouts():
  282. """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
  283. prev_out, prev_err = sys.stdout, sys.stderr
  284. mystdout, mystderr = WhateverIO(), WhateverIO()
  285. sys.stdout = sys.__stdout__ = mystdout
  286. sys.stderr = sys.__stderr__ = mystderr
  287. yield mystdout, mystderr
  288. sys.stdout = sys.__stdout__ = prev_out
  289. sys.stderr = sys.__stderr__ = prev_err
  290. def patch(module, name, mocked):
  291. module = importlib.import_module(module)
  292. def _patch(fun):
  293. @wraps(fun)
  294. def __patched(*args, **kwargs):
  295. prev = getattr(module, name)
  296. setattr(module, name, mocked)
  297. try:
  298. return fun(*args, **kwargs)
  299. finally:
  300. setattr(module, name, prev)
  301. return __patched
  302. return _patch
  303. @contextmanager
  304. def replace_module_value(module, name, value=None):
  305. has_prev = hasattr(module, name)
  306. prev = getattr(module, name, None)
  307. if value:
  308. setattr(module, name, value)
  309. else:
  310. try:
  311. delattr(module, name)
  312. except AttributeError:
  313. pass
  314. yield
  315. if prev is not None:
  316. setattr(sys, name, prev)
  317. if not has_prev:
  318. try:
  319. delattr(module, name)
  320. except AttributeError:
  321. pass
# ``with pypy_version(value):`` runs replace_module_value on
# ``sys.pypy_version_info``: the attribute is set to *value* inside the
# block, or removed when called with no argument.
pypy_version = partial(
    replace_module_value, sys, 'pypy_version_info',
)
# Same helper applied to ``platform.python_implementation``; note that
# *value* replaces the function object itself.
platform_pyimp = partial(
    replace_module_value, platform, 'python_implementation',
)
  328. @contextmanager
  329. def sys_platform(value):
  330. prev, sys.platform = sys.platform, value
  331. yield
  332. sys.platform = prev
  333. @contextmanager
  334. def reset_modules(*modules):
  335. prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
  336. yield
  337. sys.modules.update(prev)
  338. @contextmanager
  339. def patch_modules(*modules):
  340. prev = {}
  341. for mod in modules:
  342. prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
  343. yield
  344. for name, mod in prev.iteritems():
  345. if mod is None:
  346. sys.modules.pop(name, None)
  347. else:
  348. sys.modules[name] = mod
  349. @contextmanager
  350. def mock_module(*names):
  351. prev = {}
  352. class MockModule(ModuleType):
  353. def __getattr__(self, attr):
  354. setattr(self, attr, Mock())
  355. return ModuleType.__getattribute__(self, attr)
  356. mods = []
  357. for name in names:
  358. try:
  359. prev[name] = sys.modules[name]
  360. except KeyError:
  361. pass
  362. mod = sys.modules[name] = MockModule(name)
  363. mods.append(mod)
  364. try:
  365. yield mods
  366. finally:
  367. for name in names:
  368. try:
  369. sys.modules[name] = prev[name]
  370. except KeyError:
  371. try:
  372. del(sys.modules[name])
  373. except KeyError:
  374. pass
  375. @contextmanager
  376. def mock_context(mock, typ=Mock):
  377. context = mock.return_value = Mock()
  378. context.__enter__ = typ()
  379. context.__exit__ = typ()
  380. def on_exit(*x):
  381. if x[0]:
  382. raise x[0], x[1], x[2]
  383. context.__exit__.side_effect = on_exit
  384. context.__enter__.return_value = context
  385. yield context
  386. context.reset()
  387. @contextmanager
  388. def mock_open(typ=WhateverIO, side_effect=None):
  389. with mock.patch('__builtin__.open') as open_:
  390. with mock_context(open_) as context:
  391. if side_effect is not None:
  392. context.__enter__.side_effect = side_effect
  393. val = context.__enter__.return_value = typ()
  394. yield val
  395. def patch_many(*targets):
  396. return nested(*[mock.patch(target) for target in targets])
  397. @contextmanager
  398. def patch_settings(app=None, **config):
  399. if app is None:
  400. from celery import current_app
  401. app = current_app
  402. prev = {}
  403. for key, value in config.iteritems():
  404. try:
  405. prev[key] = getattr(app.conf, key)
  406. except AttributeError:
  407. pass
  408. setattr(app.conf, key, value)
  409. yield app.conf
  410. for key, value in prev.iteritems():
  411. setattr(app.conf, key, value)
  412. @contextmanager
  413. def assert_signal_called(signal, **expected):
  414. handler = Mock()
  415. call_handler = partial(handler)
  416. signal.connect(call_handler)
  417. try:
  418. yield handler
  419. finally:
  420. signal.disconnect(call_handler)
  421. handler.assert_called_with(signal=signal, **expected)
  422. def skip_if_pypy(fun):
  423. @wraps(fun)
  424. def _inner(*args, **kwargs):
  425. if getattr(sys, 'pypy_version_info', None):
  426. raise SkipTest('does not work on PyPy')
  427. return fun(*args, **kwargs)
  428. return _inner
  429. def skip_if_jython(fun):
  430. @wraps(fun)
  431. def _inner(*args, **kwargs):
  432. if sys.platform.startswith('java'):
  433. raise SkipTest('does not work on Jython')
  434. return fun(*args, **kwargs)
  435. return _inner