case.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863
  1. from __future__ import absolute_import
  2. try:
  3. import unittest # noqa
  4. unittest.skip
  5. from unittest.util import safe_repr, unorderable_list_difference
  6. except AttributeError:
  7. import unittest2 as unittest # noqa
  8. from unittest2.util import safe_repr, unorderable_list_difference # noqa
  9. import importlib
  10. import inspect
  11. import logging
  12. import numbers
  13. import os
  14. import platform
  15. import re
  16. import sys
  17. import threading
  18. import time
  19. import types
  20. import warnings
  21. from contextlib import contextmanager
  22. from copy import deepcopy
  23. from datetime import datetime, timedelta
  24. from functools import partial, wraps
  25. from types import ModuleType
  26. try:
  27. from unittest import mock
  28. except ImportError:
  29. import mock # noqa
  30. from nose import SkipTest
  31. from kombu import Queue
  32. from kombu.log import NullHandler
  33. from kombu.utils import nested, symbol_by_name
  34. from celery import Celery
  35. from celery.app import current_app
  36. from celery.backends.cache import CacheBackend, DummyClient
  37. from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
  38. from celery.five import (
  39. WhateverIO, builtins, items, reraise,
  40. string_t, values, open_fqdn,
  41. )
  42. from celery.utils.functional import noop
  43. from celery.utils.imports import qualname
  44. __all__ = [
  45. 'Case', 'AppCase', 'Mock', 'MagicMock', 'ANY',
  46. 'patch', 'call', 'sentinel', 'skip_unless_module',
  47. 'wrap_logger', 'with_environ', 'sleepdeprived',
  48. 'skip_if_environ', 'todo', 'skip', 'skip_if',
  49. 'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
  50. 'replace_module_value', 'sys_platform', 'reset_modules',
  51. 'patch_modules', 'mock_context', 'mock_open', 'patch_many',
  52. 'assert_signal_called', 'skip_if_pypy',
  53. 'skip_if_jython', 'body_from_sig', 'restore_logging',
  54. ]
  55. patch = mock.patch
  56. call = mock.call
  57. sentinel = mock.sentinel
  58. MagicMock = mock.MagicMock
  59. ANY = mock.ANY
  60. PY3 = sys.version_info[0] == 3
  61. CASE_REDEFINES_SETUP = """\
  62. {name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
  63. """
  64. CASE_REDEFINES_TEARDOWN = """\
  65. {name} (subclass of AppCase) redefines private "tearDown", \
  66. should be: "teardown"\
  67. """
  68. CASE_LOG_REDIRECT_EFFECT = """\
  69. Test {0} did not disable LoggingProxy for {1}\
  70. """
  71. CASE_LOG_LEVEL_EFFECT = """\
  72. Test {0} Modified the level of the root logger\
  73. """
  74. CASE_LOG_HANDLER_EFFECT = """\
  75. Test {0} Modified handlers for the root logger\
  76. """
  77. CELERY_TEST_CONFIG = {
  78. #: Don't want log output when running suite.
  79. 'CELERYD_HIJACK_ROOT_LOGGER': False,
  80. 'CELERY_SEND_TASK_ERROR_EMAILS': False,
  81. 'CELERY_DEFAULT_QUEUE': 'testcelery',
  82. 'CELERY_DEFAULT_EXCHANGE': 'testcelery',
  83. 'CELERY_DEFAULT_ROUTING_KEY': 'testcelery',
  84. 'CELERY_QUEUES': (
  85. Queue('testcelery', routing_key='testcelery'),
  86. ),
  87. 'CELERY_ENABLE_UTC': True,
  88. 'CELERY_TIMEZONE': 'UTC',
  89. 'CELERYD_LOG_COLOR': False,
  90. # Mongo results tests (only executed if installed and running)
  91. 'CELERY_MONGODB_BACKEND_SETTINGS': {
  92. 'host': os.environ.get('MONGO_HOST') or 'localhost',
  93. 'port': os.environ.get('MONGO_PORT') or 27017,
  94. 'database': os.environ.get('MONGO_DB') or 'celery_unittests',
  95. 'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION')
  96. or 'taskmeta_collection'),
  97. 'user': os.environ.get('MONGO_USER'),
  98. 'password': os.environ.get('MONGO_PASSWORD'),
  99. }
  100. }
  101. class Trap(object):
  102. def __getattr__(self, name):
  103. raise RuntimeError('Test depends on current_app')
  104. class UnitLogging(symbol_by_name(Celery.log_cls)):
  105. def __init__(self, *args, **kwargs):
  106. super(UnitLogging, self).__init__(*args, **kwargs)
  107. self.already_setup = True
  108. def UnitApp(name=None, broker=None, backend=None,
  109. set_as_current=False, log=UnitLogging, **kwargs):
  110. app = Celery(name or 'celery.tests',
  111. broker=broker or 'memory://',
  112. backend=backend or 'cache+memory://',
  113. set_as_current=set_as_current,
  114. log=log,
  115. **kwargs)
  116. app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
  117. return app
  118. class Mock(mock.Mock):
  119. def __init__(self, *args, **kwargs):
  120. attrs = kwargs.pop('attrs', None) or {}
  121. super(Mock, self).__init__(*args, **kwargs)
  122. for attr_name, attr_value in items(attrs):
  123. setattr(self, attr_name, attr_value)
  124. class _ContextMock(Mock):
  125. """Dummy class implementing __enter__ and __exit__
  126. as the with statement requires these to be implemented
  127. in the class, not just the instance."""
  128. def __enter__(self):
  129. pass
  130. def __exit__(self, *exc_info):
  131. pass
  132. def ContextMock(*args, **kwargs):
  133. obj = _ContextMock(*args, **kwargs)
  134. obj.attach_mock(_ContextMock(), '__enter__')
  135. obj.attach_mock(_ContextMock(), '__exit__')
  136. obj.__enter__.return_value = obj
  137. # if __exit__ return a value the exception is ignored,
  138. # so it must return None here.
  139. obj.__exit__.return_value = None
  140. return obj
  141. def _bind(f, o):
  142. @wraps(f)
  143. def bound_meth(*fargs, **fkwargs):
  144. return f(o, *fargs, **fkwargs)
  145. return bound_meth
  146. if PY3: # pragma: no cover
  147. def _get_class_fun(meth):
  148. return meth
  149. else:
  150. def _get_class_fun(meth):
  151. return meth.__func__
  152. class MockCallbacks(object):
  153. def __new__(cls, *args, **kwargs):
  154. r = Mock(name=cls.__name__)
  155. _get_class_fun(cls.__init__)(r, *args, **kwargs)
  156. for key, value in items(vars(cls)):
  157. if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
  158. if inspect.ismethod(value) or inspect.isfunction(value):
  159. r.__getattr__(key).side_effect = _bind(value, r)
  160. else:
  161. r.__setattr__(key, value)
  162. return r
  163. def skip_unless_module(module):
  164. def _inner(fun):
  165. @wraps(fun)
  166. def __inner(*args, **kwargs):
  167. try:
  168. importlib.import_module(module)
  169. except ImportError:
  170. raise SkipTest('Does not have %s' % (module, ))
  171. return fun(*args, **kwargs)
  172. return __inner
  173. return _inner
  174. # -- adds assertWarns from recent unittest2, not in Python 2.7.
  175. class _AssertRaisesBaseContext(object):
  176. def __init__(self, expected, test_case, callable_obj=None,
  177. expected_regex=None):
  178. self.expected = expected
  179. self.failureException = test_case.failureException
  180. self.obj_name = None
  181. if isinstance(expected_regex, string_t):
  182. expected_regex = re.compile(expected_regex)
  183. self.expected_regex = expected_regex
  184. def _is_magic_module(m):
  185. # some libraries create custom module types that are lazily
  186. # lodaded, e.g. Django installs some modules in sys.modules that
  187. # will load _tkinter and other shit when touched.
  188. # pyflakes refuses to accept 'noqa' for this isinstance.
  189. cls, modtype = m.__class__, types.ModuleType
  190. return (not cls is modtype and (
  191. '__getattr__' in vars(m.__class__) or
  192. '__getattribute__' in vars(m.__class__)))
  193. class _AssertWarnsContext(_AssertRaisesBaseContext):
  194. """A context manager used to implement TestCase.assertWarns* methods."""
  195. def __enter__(self):
  196. # The __warningregistry__'s need to be in a pristine state for tests
  197. # to work properly.
  198. warnings.resetwarnings()
  199. for v in list(values(sys.modules)):
  200. # do not evaluate Django moved modules and other lazily
  201. # initialized modules.
  202. if v and not _is_magic_module(v):
  203. # use raw __getattribute__ to protect even better from
  204. # lazily loaded modules
  205. try:
  206. object.__getattribute__(v, '__warningregistry__')
  207. except AttributeError:
  208. pass
  209. else:
  210. object.__setattr__(v, '__warningregistry__', {})
  211. self.warnings_manager = warnings.catch_warnings(record=True)
  212. self.warnings = self.warnings_manager.__enter__()
  213. warnings.simplefilter('always', self.expected)
  214. return self
  215. def __exit__(self, exc_type, exc_value, tb):
  216. self.warnings_manager.__exit__(exc_type, exc_value, tb)
  217. if exc_type is not None:
  218. # let unexpected exceptions pass through
  219. return
  220. try:
  221. exc_name = self.expected.__name__
  222. except AttributeError:
  223. exc_name = str(self.expected)
  224. first_matching = None
  225. for m in self.warnings:
  226. w = m.message
  227. if not isinstance(w, self.expected):
  228. continue
  229. if first_matching is None:
  230. first_matching = w
  231. if (self.expected_regex is not None and
  232. not self.expected_regex.search(str(w))):
  233. continue
  234. # store warning for later retrieval
  235. self.warning = w
  236. self.filename = m.filename
  237. self.lineno = m.lineno
  238. return
  239. # Now we simply try to choose a helpful failure message
  240. if first_matching is not None:
  241. raise self.failureException(
  242. '%r does not match %r' % (
  243. self.expected_regex.pattern, str(first_matching)))
  244. if self.obj_name:
  245. raise self.failureException(
  246. '%s not triggered by %s' % (exc_name, self.obj_name))
  247. else:
  248. raise self.failureException('%s not triggered' % exc_name)
  249. class Case(unittest.TestCase):
  250. def assertWarns(self, expected_warning):
  251. return _AssertWarnsContext(expected_warning, self, None)
  252. def assertWarnsRegex(self, expected_warning, expected_regex):
  253. return _AssertWarnsContext(expected_warning, self,
  254. None, expected_regex)
  255. @contextmanager
  256. def assertDeprecated(self):
  257. with self.assertWarnsRegex(CDeprecationWarning,
  258. r'scheduled for removal'):
  259. yield
  260. @contextmanager
  261. def assertPendingDeprecation(self):
  262. with self.assertWarnsRegex(CPendingDeprecationWarning,
  263. r'scheduled for deprecation'):
  264. yield
  265. def assertDictContainsSubset(self, expected, actual, msg=None):
  266. missing, mismatched = [], []
  267. for key, value in items(expected):
  268. if key not in actual:
  269. missing.append(key)
  270. elif value != actual[key]:
  271. mismatched.append('%s, expected: %s, actual: %s' % (
  272. safe_repr(key), safe_repr(value),
  273. safe_repr(actual[key])))
  274. if not (missing or mismatched):
  275. return
  276. standard_msg = ''
  277. if missing:
  278. standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
  279. if mismatched:
  280. if standard_msg:
  281. standard_msg += '; '
  282. standard_msg += 'Mismatched values: %s' % (
  283. ','.join(mismatched))
  284. self.fail(self._formatMessage(msg, standard_msg))
  285. def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  286. missing = unexpected = None
  287. try:
  288. expected = sorted(expected_seq)
  289. actual = sorted(actual_seq)
  290. except TypeError:
  291. # Unsortable items (example: set(), complex(), ...)
  292. expected = list(expected_seq)
  293. actual = list(actual_seq)
  294. missing, unexpected = unorderable_list_difference(
  295. expected, actual)
  296. else:
  297. return self.assertSequenceEqual(expected, actual, msg=msg)
  298. errors = []
  299. if missing:
  300. errors.append(
  301. 'Expected, but missing:\n %s' % (safe_repr(missing), )
  302. )
  303. if unexpected:
  304. errors.append(
  305. 'Unexpected, but present:\n %s' % (safe_repr(unexpected), )
  306. )
  307. if errors:
  308. standardMsg = '\n'.join(errors)
  309. self.fail(self._formatMessage(msg, standardMsg))
  310. def depends_on_current_app(fun):
  311. if inspect.isclass(fun):
  312. fun.contained = False
  313. else:
  314. @wraps(fun)
  315. def __inner(self, *args, **kwargs):
  316. self.app.set_current()
  317. return fun(self, *args, **kwargs)
  318. return __inner
  319. class AppCase(Case):
  320. contained = True
  321. def __init__(self, *args, **kwargs):
  322. super(AppCase, self).__init__(*args, **kwargs)
  323. if self.__class__.__dict__.get('setUp'):
  324. raise RuntimeError(
  325. CASE_REDEFINES_SETUP.format(name=qualname(self)),
  326. )
  327. if self.__class__.__dict__.get('tearDown'):
  328. raise RuntimeError(
  329. CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
  330. )
  331. def Celery(self, *args, **kwargs):
  332. return UnitApp(*args, **kwargs)
  333. def setUp(self):
  334. self._threads_at_setup = list(threading.enumerate())
  335. from celery import _state
  336. from celery import result
  337. result.task_join_will_block = \
  338. _state.task_join_will_block = lambda: False
  339. self._current_app = current_app()
  340. self._default_app = _state.default_app
  341. trap = Trap()
  342. self._prev_tls = _state._tls
  343. _state.set_default_app(trap)
  344. class NonTLS(object):
  345. current_app = trap
  346. _state._tls = NonTLS()
  347. self.app = self.Celery(set_as_current=False)
  348. if not self.contained:
  349. self.app.set_current()
  350. root = logging.getLogger()
  351. self.__rootlevel = root.level
  352. self.__roothandlers = root.handlers
  353. _state._set_task_join_will_block(False)
  354. try:
  355. self.setup()
  356. except:
  357. self._teardown_app()
  358. raise
  359. def _teardown_app(self):
  360. from celery.utils.log import LoggingProxy
  361. assert sys.stdout
  362. assert sys.stderr
  363. assert sys.__stdout__
  364. assert sys.__stderr__
  365. this = self._get_test_name()
  366. if isinstance(sys.stdout, LoggingProxy) or \
  367. isinstance(sys.__stdout__, LoggingProxy):
  368. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
  369. if isinstance(sys.stderr, LoggingProxy) or \
  370. isinstance(sys.__stderr__, LoggingProxy):
  371. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
  372. backend = self.app.__dict__.get('backend')
  373. if backend is not None:
  374. if isinstance(backend, CacheBackend):
  375. if isinstance(backend.client, DummyClient):
  376. backend.client.cache.clear()
  377. backend._cache.clear()
  378. from celery import _state
  379. _state._set_task_join_will_block(False)
  380. _state.set_default_app(self._default_app)
  381. _state._tls = self._prev_tls
  382. _state._tls.current_app = self._current_app
  383. if self.app is not self._current_app:
  384. self.app.close()
  385. self.app = None
  386. self.assertEqual(
  387. self._threads_at_setup, list(threading.enumerate()),
  388. )
  389. def _get_test_name(self):
  390. return '.'.join([self.__class__.__name__, self._testMethodName])
  391. def tearDown(self):
  392. try:
  393. self.teardown()
  394. finally:
  395. self._teardown_app()
  396. self.assert_no_logging_side_effect()
  397. def assert_no_logging_side_effect(self):
  398. this = self._get_test_name()
  399. root = logging.getLogger()
  400. if root.level != self.__rootlevel:
  401. raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
  402. if root.handlers != self.__roothandlers:
  403. raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
  404. def setup(self):
  405. pass
  406. def teardown(self):
  407. pass
  408. def get_handlers(logger):
  409. return [h for h in logger.handlers if not isinstance(h, NullHandler)]
  410. @contextmanager
  411. def wrap_logger(logger, loglevel=logging.ERROR):
  412. old_handlers = get_handlers(logger)
  413. sio = WhateverIO()
  414. siohandler = logging.StreamHandler(sio)
  415. logger.handlers = [siohandler]
  416. try:
  417. yield sio
  418. finally:
  419. logger.handlers = old_handlers
  420. def with_environ(env_name, env_value):
  421. def _envpatched(fun):
  422. @wraps(fun)
  423. def _patch_environ(*args, **kwargs):
  424. prev_val = os.environ.get(env_name)
  425. os.environ[env_name] = env_value
  426. try:
  427. return fun(*args, **kwargs)
  428. finally:
  429. os.environ[env_name] = prev_val or ''
  430. return _patch_environ
  431. return _envpatched
  432. def sleepdeprived(module=time):
  433. def _sleepdeprived(fun):
  434. @wraps(fun)
  435. def __sleepdeprived(*args, **kwargs):
  436. old_sleep = module.sleep
  437. module.sleep = noop
  438. try:
  439. return fun(*args, **kwargs)
  440. finally:
  441. module.sleep = old_sleep
  442. return __sleepdeprived
  443. return _sleepdeprived
  444. def skip_if_environ(env_var_name):
  445. def _wrap_test(fun):
  446. @wraps(fun)
  447. def _skips_if_environ(*args, **kwargs):
  448. if os.environ.get(env_var_name):
  449. raise SkipTest('SKIP %s: %s set\n' % (
  450. fun.__name__, env_var_name))
  451. return fun(*args, **kwargs)
  452. return _skips_if_environ
  453. return _wrap_test
  454. def _skip_test(reason, sign):
  455. def _wrap_test(fun):
  456. @wraps(fun)
  457. def _skipped_test(*args, **kwargs):
  458. raise SkipTest('%s: %s' % (sign, reason))
  459. return _skipped_test
  460. return _wrap_test
  461. def todo(reason):
  462. """TODO test decorator."""
  463. return _skip_test(reason, 'TODO')
  464. def skip(reason):
  465. """Skip test decorator."""
  466. return _skip_test(reason, 'SKIP')
  467. def skip_if(predicate, reason):
  468. """Skip test if predicate is :const:`True`."""
  469. def _inner(fun):
  470. return predicate and skip(reason)(fun) or fun
  471. return _inner
  472. def skip_unless(predicate, reason):
  473. """Skip test if predicate is :const:`False`."""
  474. return skip_if(not predicate, reason)
  475. # Taken from
  476. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  477. @contextmanager
  478. def mask_modules(*modnames):
  479. """Ban some modules from being importable inside the context
  480. For example:
  481. >>> with mask_modules('sys'):
  482. ... try:
  483. ... import sys
  484. ... except ImportError:
  485. ... print('sys not found')
  486. sys not found
  487. >>> import sys # noqa
  488. >>> sys.version
  489. (2, 5, 2, 'final', 0)
  490. """
  491. realimport = builtins.__import__
  492. def myimp(name, *args, **kwargs):
  493. if name in modnames:
  494. raise ImportError('No module named %s' % name)
  495. else:
  496. return realimport(name, *args, **kwargs)
  497. builtins.__import__ = myimp
  498. try:
  499. yield True
  500. finally:
  501. builtins.__import__ = realimport
  502. @contextmanager
  503. def override_stdouts():
  504. """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
  505. prev_out, prev_err = sys.stdout, sys.stderr
  506. mystdout, mystderr = WhateverIO(), WhateverIO()
  507. sys.stdout = sys.__stdout__ = mystdout
  508. sys.stderr = sys.__stderr__ = mystderr
  509. try:
  510. yield mystdout, mystderr
  511. finally:
  512. sys.stdout = sys.__stdout__ = prev_out
  513. sys.stderr = sys.__stderr__ = prev_err
  514. def _old_patch(module, name, mocked):
  515. module = importlib.import_module(module)
  516. def _patch(fun):
  517. @wraps(fun)
  518. def __patched(*args, **kwargs):
  519. prev = getattr(module, name)
  520. setattr(module, name, mocked)
  521. try:
  522. return fun(*args, **kwargs)
  523. finally:
  524. setattr(module, name, prev)
  525. return __patched
  526. return _patch
  527. @contextmanager
  528. def replace_module_value(module, name, value=None):
  529. has_prev = hasattr(module, name)
  530. prev = getattr(module, name, None)
  531. if value:
  532. setattr(module, name, value)
  533. else:
  534. try:
  535. delattr(module, name)
  536. except AttributeError:
  537. pass
  538. try:
  539. yield
  540. finally:
  541. if prev is not None:
  542. setattr(sys, name, prev)
  543. if not has_prev:
  544. try:
  545. delattr(module, name)
  546. except AttributeError:
  547. pass
  548. pypy_version = partial(
  549. replace_module_value, sys, 'pypy_version_info',
  550. )
  551. platform_pyimp = partial(
  552. replace_module_value, platform, 'python_implementation',
  553. )
  554. @contextmanager
  555. def sys_platform(value):
  556. prev, sys.platform = sys.platform, value
  557. try:
  558. yield
  559. finally:
  560. sys.platform = prev
  561. @contextmanager
  562. def reset_modules(*modules):
  563. prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
  564. try:
  565. yield
  566. finally:
  567. sys.modules.update(prev)
  568. @contextmanager
  569. def patch_modules(*modules):
  570. prev = {}
  571. for mod in modules:
  572. prev[mod] = sys.modules.get(mod)
  573. sys.modules[mod] = ModuleType(mod)
  574. try:
  575. yield
  576. finally:
  577. for name, mod in items(prev):
  578. if mod is None:
  579. sys.modules.pop(name, None)
  580. else:
  581. sys.modules[name] = mod
  582. @contextmanager
  583. def mock_module(*names):
  584. prev = {}
  585. class MockModule(ModuleType):
  586. def __getattr__(self, attr):
  587. setattr(self, attr, Mock())
  588. return ModuleType.__getattribute__(self, attr)
  589. mods = []
  590. for name in names:
  591. try:
  592. prev[name] = sys.modules[name]
  593. except KeyError:
  594. pass
  595. mod = sys.modules[name] = MockModule(name)
  596. mods.append(mod)
  597. try:
  598. yield mods
  599. finally:
  600. for name in names:
  601. try:
  602. sys.modules[name] = prev[name]
  603. except KeyError:
  604. try:
  605. del(sys.modules[name])
  606. except KeyError:
  607. pass
  608. @contextmanager
  609. def mock_context(mock, typ=Mock):
  610. context = mock.return_value = Mock()
  611. context.__enter__ = typ()
  612. context.__exit__ = typ()
  613. def on_exit(*x):
  614. if x[0]:
  615. reraise(x[0], x[1], x[2])
  616. context.__exit__.side_effect = on_exit
  617. context.__enter__.return_value = context
  618. try:
  619. yield context
  620. finally:
  621. context.reset()
  622. @contextmanager
  623. def mock_open(typ=WhateverIO, side_effect=None):
  624. with patch(open_fqdn) as open_:
  625. with mock_context(open_) as context:
  626. if side_effect is not None:
  627. context.__enter__.side_effect = side_effect
  628. val = context.__enter__.return_value = typ()
  629. val.__exit__ = Mock()
  630. yield val
  631. def patch_many(*targets):
  632. return nested(*[patch(target) for target in targets])
  633. @contextmanager
  634. def assert_signal_called(signal, **expected):
  635. handler = Mock()
  636. call_handler = partial(handler)
  637. signal.connect(call_handler)
  638. try:
  639. yield handler
  640. finally:
  641. signal.disconnect(call_handler)
  642. handler.assert_called_with(signal=signal, **expected)
  643. def skip_if_pypy(fun):
  644. @wraps(fun)
  645. def _inner(*args, **kwargs):
  646. if getattr(sys, 'pypy_version_info', None):
  647. raise SkipTest('does not work on PyPy')
  648. return fun(*args, **kwargs)
  649. return _inner
  650. def skip_if_jython(fun):
  651. @wraps(fun)
  652. def _inner(*args, **kwargs):
  653. if sys.platform.startswith('java'):
  654. raise SkipTest('does not work on Jython')
  655. return fun(*args, **kwargs)
  656. return _inner
  657. def body_from_sig(app, sig, utc=True):
  658. sig.freeze()
  659. callbacks = sig.options.pop('link', None)
  660. errbacks = sig.options.pop('link_error', None)
  661. countdown = sig.options.pop('countdown', None)
  662. if countdown:
  663. eta = app.now() + timedelta(seconds=countdown)
  664. else:
  665. eta = sig.options.pop('eta', None)
  666. if eta and isinstance(eta, datetime):
  667. eta = eta.isoformat()
  668. expires = sig.options.pop('expires', None)
  669. if expires and isinstance(expires, numbers.Real):
  670. expires = app.now() + timedelta(seconds=expires)
  671. if expires and isinstance(expires, datetime):
  672. expires = expires.isoformat()
  673. return {
  674. 'task': sig.task,
  675. 'id': sig.id,
  676. 'args': sig.args,
  677. 'kwargs': sig.kwargs,
  678. 'callbacks': [dict(s) for s in callbacks] if callbacks else None,
  679. 'errbacks': [dict(s) for s in errbacks] if errbacks else None,
  680. 'eta': eta,
  681. 'utc': utc,
  682. 'expires': expires,
  683. }
  684. @contextmanager
  685. def restore_logging():
  686. outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
  687. root = logging.getLogger()
  688. level = root.level
  689. handlers = root.handlers
  690. try:
  691. yield
  692. finally:
  693. sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
  694. root.level = level
  695. root.handlers[:] = handlers