utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. from __future__ import absolute_import
  2. from __future__ import with_statement
  3. try:
  4. import unittest
  5. unittest.skip
  6. from unittest.util import safe_repr, unorderable_list_difference
  7. except AttributeError:
  8. import unittest2 as unittest
  9. from unittest2.util import safe_repr, unorderable_list_difference # noqa
  10. import importlib
  11. import logging
  12. import os
  13. import platform
  14. import re
  15. import sys
  16. import time
  17. import warnings
  18. try:
  19. import __builtin__ as builtins
  20. except ImportError: # py3k
  21. import builtins # noqa
  22. from contextlib import contextmanager
  23. from functools import partial, wraps
  24. from types import ModuleType
  25. import mock
  26. from nose import SkipTest
  27. from kombu.log import NullHandler
  28. from kombu.utils import nested
  29. from celery.app import app_or_default
  30. from celery.utils.compat import WhateverIO
  31. from celery.utils.functional import noop
  32. from .compat import catch_warnings
  33. class Mock(mock.Mock):
  34. def __init__(self, *args, **kwargs):
  35. attrs = kwargs.pop('attrs', None) or {}
  36. super(Mock, self).__init__(*args, **kwargs)
  37. for attr_name, attr_value in attrs.items():
  38. setattr(self, attr_name, attr_value)
  39. def skip_unless_module(module):
  40. def _inner(fun):
  41. @wraps(fun)
  42. def __inner(*args, **kwargs):
  43. try:
  44. importlib.import_module(module)
  45. except ImportError:
  46. raise SkipTest('Does not have %s' % (module, ))
  47. return fun(*args, **kwargs)
  48. return __inner
  49. return _inner
  50. # -- adds assertWarns from recent unittest2, not in Python 2.7.
  51. class _AssertRaisesBaseContext(object):
  52. def __init__(self, expected, test_case, callable_obj=None,
  53. expected_regex=None):
  54. self.expected = expected
  55. self.failureException = test_case.failureException
  56. self.obj_name = None
  57. if isinstance(expected_regex, basestring):
  58. expected_regex = re.compile(expected_regex)
  59. self.expected_regex = expected_regex
  60. class _AssertWarnsContext(_AssertRaisesBaseContext):
  61. """A context manager used to implement TestCase.assertWarns* methods."""
  62. def __enter__(self):
  63. # The __warningregistry__'s need to be in a pristine state for tests
  64. # to work properly.
  65. warnings.resetwarnings()
  66. for v in sys.modules.values():
  67. if getattr(v, '__warningregistry__', None):
  68. v.__warningregistry__ = {}
  69. self.warnings_manager = catch_warnings(record=True)
  70. self.warnings = self.warnings_manager.__enter__()
  71. warnings.simplefilter('always', self.expected)
  72. return self
  73. def __exit__(self, exc_type, exc_value, tb):
  74. self.warnings_manager.__exit__(exc_type, exc_value, tb)
  75. if exc_type is not None:
  76. # let unexpected exceptions pass through
  77. return
  78. try:
  79. exc_name = self.expected.__name__
  80. except AttributeError:
  81. exc_name = str(self.expected)
  82. first_matching = None
  83. for m in self.warnings:
  84. w = m.message
  85. if not isinstance(w, self.expected):
  86. continue
  87. if first_matching is None:
  88. first_matching = w
  89. if (self.expected_regex is not None and
  90. not self.expected_regex.search(str(w))):
  91. continue
  92. # store warning for later retrieval
  93. self.warning = w
  94. self.filename = m.filename
  95. self.lineno = m.lineno
  96. return
  97. # Now we simply try to choose a helpful failure message
  98. if first_matching is not None:
  99. raise self.failureException('%r does not match %r' %
  100. (self.expected_regex.pattern, str(first_matching)))
  101. if self.obj_name:
  102. raise self.failureException('%s not triggered by %s'
  103. % (exc_name, self.obj_name))
  104. else:
  105. raise self.failureException('%s not triggered'
  106. % exc_name)
  107. class Case(unittest.TestCase):
  108. def assertWarns(self, expected_warning):
  109. return _AssertWarnsContext(expected_warning, self, None)
  110. def assertWarnsRegex(self, expected_warning, expected_regex):
  111. return _AssertWarnsContext(expected_warning, self,
  112. None, expected_regex)
  113. def assertDictContainsSubset(self, expected, actual, msg=None):
  114. missing, mismatched = [], []
  115. for key, value in expected.iteritems():
  116. if key not in actual:
  117. missing.append(key)
  118. elif value != actual[key]:
  119. mismatched.append('%s, expected: %s, actual: %s' % (
  120. safe_repr(key), safe_repr(value),
  121. safe_repr(actual[key])))
  122. if not (missing or mismatched):
  123. return
  124. standard_msg = ''
  125. if missing:
  126. standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
  127. if mismatched:
  128. if standard_msg:
  129. standard_msg += '; '
  130. standard_msg += 'Mismatched values: %s' % (
  131. ','.join(mismatched))
  132. self.fail(self._formatMessage(msg, standard_msg))
  133. def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  134. try:
  135. expected = sorted(expected_seq)
  136. actual = sorted(actual_seq)
  137. except TypeError:
  138. # Unsortable items (example: set(), complex(), ...)
  139. expected = list(expected_seq)
  140. actual = list(actual_seq)
  141. missing, unexpected = unorderable_list_difference(
  142. expected, actual)
  143. else:
  144. return self.assertSequenceEqual(expected, actual, msg=msg)
  145. errors = []
  146. if missing:
  147. errors.append('Expected, but missing:\n %s' % (
  148. safe_repr(missing)))
  149. if unexpected:
  150. errors.append('Unexpected, but present:\n %s' % (
  151. safe_repr(unexpected)))
  152. if errors:
  153. standardMsg = '\n'.join(errors)
  154. self.fail(self._formatMessage(msg, standardMsg))
  155. class AppCase(Case):
  156. def setUp(self):
  157. from celery.app import current_app
  158. from celery.backends.cache import CacheBackend, DummyClient
  159. app = self.app = self._current_app = current_app()
  160. if isinstance(app.backend, CacheBackend):
  161. if isinstance(app.backend.client, DummyClient):
  162. app.backend.client.cache.clear()
  163. app.backend._cache.clear()
  164. self.setup()
  165. def tearDown(self):
  166. self.teardown()
  167. self._current_app.set_current()
  168. def setup(self):
  169. pass
  170. def teardown(self):
  171. pass
  172. def get_handlers(logger):
  173. return [h for h in logger.handlers if not isinstance(h, NullHandler)]
  174. @contextmanager
  175. def wrap_logger(logger, loglevel=logging.ERROR):
  176. old_handlers = get_handlers(logger)
  177. sio = WhateverIO()
  178. siohandler = logging.StreamHandler(sio)
  179. logger.handlers = [siohandler]
  180. yield sio
  181. logger.handlers = old_handlers
  182. @contextmanager
  183. def eager_tasks():
  184. app = app_or_default()
  185. prev = app.conf.CELERY_ALWAYS_EAGER
  186. app.conf.CELERY_ALWAYS_EAGER = True
  187. yield True
  188. app.conf.CELERY_ALWAYS_EAGER = prev
  189. def with_eager_tasks(fun):
  190. @wraps(fun)
  191. def _inner(*args, **kwargs):
  192. app = app_or_default()
  193. prev = app.conf.CELERY_ALWAYS_EAGER
  194. app.conf.CELERY_ALWAYS_EAGER = True
  195. try:
  196. return fun(*args, **kwargs)
  197. finally:
  198. app.conf.CELERY_ALWAYS_EAGER = prev
  199. def with_environ(env_name, env_value):
  200. def _envpatched(fun):
  201. @wraps(fun)
  202. def _patch_environ(*args, **kwargs):
  203. prev_val = os.environ.get(env_name)
  204. os.environ[env_name] = env_value
  205. try:
  206. return fun(*args, **kwargs)
  207. finally:
  208. if prev_val is not None:
  209. os.environ[env_name] = prev_val
  210. return _patch_environ
  211. return _envpatched
  212. def sleepdeprived(module=time):
  213. def _sleepdeprived(fun):
  214. @wraps(fun)
  215. def __sleepdeprived(*args, **kwargs):
  216. old_sleep = module.sleep
  217. module.sleep = noop
  218. try:
  219. return fun(*args, **kwargs)
  220. finally:
  221. module.sleep = old_sleep
  222. return __sleepdeprived
  223. return _sleepdeprived
  224. def skip_if_environ(env_var_name):
  225. def _wrap_test(fun):
  226. @wraps(fun)
  227. def _skips_if_environ(*args, **kwargs):
  228. if os.environ.get(env_var_name):
  229. raise SkipTest('SKIP %s: %s set\n' % (
  230. fun.__name__, env_var_name))
  231. return fun(*args, **kwargs)
  232. return _skips_if_environ
  233. return _wrap_test
  234. def skip_if_quick(fun):
  235. return skip_if_environ('QUICKTEST')(fun)
  236. def _skip_test(reason, sign):
  237. def _wrap_test(fun):
  238. @wraps(fun)
  239. def _skipped_test(*args, **kwargs):
  240. raise SkipTest('%s: %s' % (sign, reason))
  241. return _skipped_test
  242. return _wrap_test
  243. def todo(reason):
  244. """TODO test decorator."""
  245. return _skip_test(reason, 'TODO')
  246. def skip(reason):
  247. """Skip test decorator."""
  248. return _skip_test(reason, 'SKIP')
  249. def skip_if(predicate, reason):
  250. """Skip test if predicate is :const:`True`."""
  251. def _inner(fun):
  252. return predicate and skip(reason)(fun) or fun
  253. return _inner
  254. def skip_unless(predicate, reason):
  255. """Skip test if predicate is :const:`False`."""
  256. return skip_if(not predicate, reason)
  257. # Taken from
  258. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  259. @contextmanager
  260. def mask_modules(*modnames):
  261. """Ban some modules from being importable inside the context
  262. For example:
  263. >>> with missing_modules('sys'):
  264. ... try:
  265. ... import sys
  266. ... except ImportError:
  267. ... print 'sys not found'
  268. sys not found
  269. >>> import sys
  270. >>> sys.version
  271. (2, 5, 2, 'final', 0)
  272. """
  273. realimport = builtins.__import__
  274. def myimp(name, *args, **kwargs):
  275. if name in modnames:
  276. raise ImportError('No module named %s' % name)
  277. else:
  278. return realimport(name, *args, **kwargs)
  279. builtins.__import__ = myimp
  280. yield True
  281. builtins.__import__ = realimport
  282. @contextmanager
  283. def override_stdouts():
  284. """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
  285. prev_out, prev_err = sys.stdout, sys.stderr
  286. mystdout, mystderr = WhateverIO(), WhateverIO()
  287. sys.stdout = sys.__stdout__ = mystdout
  288. sys.stderr = sys.__stderr__ = mystderr
  289. yield mystdout, mystderr
  290. sys.stdout = sys.__stdout__ = prev_out
  291. sys.stderr = sys.__stderr__ = prev_err
  292. def patch(module, name, mocked):
  293. module = importlib.import_module(module)
  294. def _patch(fun):
  295. @wraps(fun)
  296. def __patched(*args, **kwargs):
  297. prev = getattr(module, name)
  298. setattr(module, name, mocked)
  299. try:
  300. return fun(*args, **kwargs)
  301. finally:
  302. setattr(module, name, prev)
  303. return __patched
  304. return _patch
  305. @contextmanager
  306. def replace_module_value(module, name, value=None):
  307. has_prev = hasattr(module, name)
  308. prev = getattr(module, name, None)
  309. if value:
  310. setattr(module, name, value)
  311. else:
  312. try:
  313. delattr(module, name)
  314. except AttributeError:
  315. pass
  316. yield
  317. if prev is not None:
  318. setattr(sys, name, prev)
  319. if not has_prev:
  320. try:
  321. delattr(module, name)
  322. except AttributeError:
  323. pass
  324. pypy_version = partial(
  325. replace_module_value, sys, 'pypy_version_info',
  326. )
  327. platform_pyimp = partial(
  328. replace_module_value, platform, 'python_implementation',
  329. )
  330. @contextmanager
  331. def sys_platform(value):
  332. prev, sys.platform = sys.platform, value
  333. yield
  334. sys.platform = prev
  335. @contextmanager
  336. def reset_modules(*modules):
  337. prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
  338. yield
  339. sys.modules.update(prev)
  340. @contextmanager
  341. def patch_modules(*modules):
  342. prev = {}
  343. for mod in modules:
  344. prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
  345. yield
  346. for name, mod in prev.iteritems():
  347. if mod is None:
  348. sys.modules.pop(name, None)
  349. else:
  350. sys.modules[name] = mod
  351. @contextmanager
  352. def mock_module(*names):
  353. prev = {}
  354. class MockModule(ModuleType):
  355. def __getattr__(self, attr):
  356. setattr(self, attr, Mock())
  357. return ModuleType.__getattribute__(self, attr)
  358. mods = []
  359. for name in names:
  360. try:
  361. prev[name] = sys.modules[name]
  362. except KeyError:
  363. pass
  364. mod = sys.modules[name] = MockModule(name)
  365. mods.append(mod)
  366. try:
  367. yield mods
  368. finally:
  369. for name in names:
  370. try:
  371. sys.modules[name] = prev[name]
  372. except KeyError:
  373. try:
  374. del(sys.modules[name])
  375. except KeyError:
  376. pass
  377. @contextmanager
  378. def mock_context(mock, typ=Mock):
  379. context = mock.return_value = Mock()
  380. context.__enter__ = typ()
  381. context.__exit__ = typ()
  382. def on_exit(*x):
  383. if x[0]:
  384. raise x[0], x[1], x[2]
  385. context.__exit__.side_effect = on_exit
  386. context.__enter__.return_value = context
  387. yield context
  388. context.reset()
  389. @contextmanager
  390. def mock_open(typ=WhateverIO, side_effect=None):
  391. with mock.patch('__builtin__.open') as open_:
  392. with mock_context(open_) as context:
  393. if side_effect is not None:
  394. context.__enter__.side_effect = side_effect
  395. val = context.__enter__.return_value = typ()
  396. val.__exit__ = Mock()
  397. yield val
  398. def patch_many(*targets):
  399. return nested(*[mock.patch(target) for target in targets])
  400. @contextmanager
  401. def patch_settings(app=None, **config):
  402. if app is None:
  403. from celery import current_app
  404. app = current_app
  405. prev = {}
  406. for key, value in config.iteritems():
  407. try:
  408. prev[key] = getattr(app.conf, key)
  409. except AttributeError:
  410. pass
  411. setattr(app.conf, key, value)
  412. yield app.conf
  413. for key, value in prev.iteritems():
  414. setattr(app.conf, key, value)
  415. @contextmanager
  416. def assert_signal_called(signal, **expected):
  417. handler = Mock()
  418. call_handler = partial(handler)
  419. signal.connect(call_handler)
  420. try:
  421. yield handler
  422. finally:
  423. signal.disconnect(call_handler)
  424. handler.assert_called_with(signal=signal, **expected)
  425. def skip_if_pypy(fun):
  426. @wraps(fun)
  427. def _inner(*args, **kwargs):
  428. if getattr(sys, 'pypy_version_info', None):
  429. raise SkipTest('does not work on PyPy')
  430. return fun(*args, **kwargs)
  431. return _inner
  432. def skip_if_jython(fun):
  433. @wraps(fun)
  434. def _inner(*args, **kwargs):
  435. if sys.platform.startswith('java'):
  436. raise SkipTest('does not work on Jython')
  437. return fun(*args, **kwargs)
  438. return _inner