case.py 28 KB


  1. from __future__ import absolute_import
  2. try:
  3. import unittest # noqa
  4. unittest.skip
  5. from unittest.util import safe_repr, unorderable_list_difference
  6. except AttributeError:
  7. import unittest2 as unittest # noqa
  8. from unittest2.util import safe_repr, unorderable_list_difference # noqa
  9. import importlib
  10. import inspect
  11. import logging
  12. import numbers
  13. import os
  14. import platform
  15. import re
  16. import sys
  17. import threading
  18. import time
  19. import types
  20. import warnings
  21. from contextlib import contextmanager
  22. from copy import deepcopy
  23. from datetime import datetime, timedelta
  24. from functools import partial, wraps
  25. from types import ModuleType
  26. try:
  27. from unittest import mock
  28. except ImportError:
  29. import mock # noqa
  30. from nose import SkipTest
  31. from kombu import Queue
  32. from kombu.log import NullHandler
  33. from kombu.utils import symbol_by_name
  34. from celery import Celery
  35. from celery.app import current_app
  36. from celery.backends.cache import CacheBackend, DummyClient
  37. from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
  38. from celery.five import (
  39. WhateverIO, builtins, items, reraise,
  40. string_t, values, open_fqdn,
  41. )
  42. from celery.utils.functional import noop
  43. from celery.utils.imports import qualname
#: Names exported by ``from <this module> import *``.
__all__ = [
    'Case', 'AppCase', 'Mock', 'MagicMock', 'ANY', 'TaskMessage',
    'patch', 'call', 'sentinel', 'skip_unless_module',
    'wrap_logger', 'with_environ', 'sleepdeprived',
    'skip_if_environ', 'todo', 'skip', 'skip_if',
    'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
    'replace_module_value', 'sys_platform', 'reset_modules',
    'patch_modules', 'mock_context', 'mock_open',
    'assert_signal_called', 'skip_if_pypy',
    'skip_if_jython', 'task_message_from_sig', 'restore_logging',
]

# Module-level aliases for the ``mock`` helpers, so tests can import them
# from here regardless of which ``mock`` implementation was imported above.
patch = mock.patch
call = mock.call
sentinel = mock.sentinel
MagicMock = mock.MagicMock
ANY = mock.ANY

#: True when running on a Python 3 interpreter.
PY3 = sys.version_info[0] == 3

# Error text used by AppCase.__init__ when a subclass defines ``setUp``.
CASE_REDEFINES_SETUP = """\
{name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
"""

# Error text used by AppCase.__init__ when a subclass defines ``tearDown``.
CASE_REDEFINES_TEARDOWN = """\
{name} (subclass of AppCase) redefines private "tearDown", \
should be: "teardown"\
"""

# Error text used by AppCase._teardown_app; {0} is the test name and
# {1} is the stream name ('stdout' or 'stderr').
CASE_LOG_REDIRECT_EFFECT = """\
Test {0} did not disable LoggingProxy for {1}\
"""

# Error text used by AppCase.assert_no_logging_side_effect; {0} is the
# test name.
CASE_LOG_LEVEL_EFFECT = """\
Test {0} Modified the level of the root logger\
"""

# Error text used by AppCase.assert_no_logging_side_effect; {0} is the
# test name.
CASE_LOG_HANDLER_EFFECT = """\
Test {0} Modified handlers for the root logger\
"""
  74. CASE_LOG_HANDLER_EFFECT = """\
  75. Test {0} Modified handlers for the root logger\
  76. """
  77. CELERY_TEST_CONFIG = {
  78. #: Don't want log output when running suite.
  79. 'worker_hijack_root_logger': False,
  80. 'worker_log_color': False,
  81. 'task_send_error_emails': False,
  82. 'task_default_queue': 'testcelery',
  83. 'task_default_exchange': 'testcelery',
  84. 'task_default_routing_key': 'testcelery',
  85. 'task_queues': (
  86. Queue('testcelery', routing_key='testcelery'),
  87. ),
  88. 'accept_content': ('json', 'pickle'),
  89. 'enable_utc': True,
  90. 'timezone': 'UTC',
  91. # Mongo results tests (only executed if installed and running)
  92. 'mongodb_backend_settings': {
  93. 'host': os.environ.get('MONGO_HOST') or 'localhost',
  94. 'port': os.environ.get('MONGO_PORT') or 27017,
  95. 'database': os.environ.get('MONGO_DB') or 'celery_unittests',
  96. 'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
  97. 'taskmeta_collection'),
  98. 'user': os.environ.get('MONGO_USER'),
  99. 'password': os.environ.get('MONGO_PASSWORD'),
  100. }
  101. }
  102. class Trap(object):
  103. def __getattr__(self, name):
  104. raise RuntimeError('Test depends on current_app')
  105. class UnitLogging(symbol_by_name(Celery.log_cls)):
  106. def __init__(self, *args, **kwargs):
  107. super(UnitLogging, self).__init__(*args, **kwargs)
  108. self.already_setup = True
  109. def UnitApp(name=None, set_as_current=False, log=UnitLogging,
  110. broker='memory://', backend='cache+memory://', **kwargs):
  111. app = Celery(name or 'celery.tests',
  112. set_as_current=set_as_current,
  113. log=log, broker=broker, backend=backend,
  114. **kwargs)
  115. app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
  116. return app
  117. class Mock(mock.Mock):
  118. def __init__(self, *args, **kwargs):
  119. attrs = kwargs.pop('attrs', None) or {}
  120. super(Mock, self).__init__(*args, **kwargs)
  121. for attr_name, attr_value in items(attrs):
  122. setattr(self, attr_name, attr_value)
  123. class _ContextMock(Mock):
  124. """Dummy class implementing __enter__ and __exit__
  125. as the with statement requires these to be implemented
  126. in the class, not just the instance."""
  127. def __enter__(self):
  128. return self
  129. def __exit__(self, *exc_info):
  130. pass
  131. def ContextMock(*args, **kwargs):
  132. obj = _ContextMock(*args, **kwargs)
  133. obj.attach_mock(_ContextMock(), '__enter__')
  134. obj.attach_mock(_ContextMock(), '__exit__')
  135. obj.__enter__.return_value = obj
  136. # if __exit__ return a value the exception is ignored,
  137. # so it must return None here.
  138. obj.__exit__.return_value = None
  139. return obj
  140. def _bind(f, o):
  141. @wraps(f)
  142. def bound_meth(*fargs, **fkwargs):
  143. return f(o, *fargs, **fkwargs)
  144. return bound_meth
  145. if PY3: # pragma: no cover
  146. def _get_class_fun(meth):
  147. return meth
  148. else:
  149. def _get_class_fun(meth):
  150. return meth.__func__
  151. class MockCallbacks(object):
  152. def __new__(cls, *args, **kwargs):
  153. r = Mock(name=cls.__name__)
  154. _get_class_fun(cls.__init__)(r, *args, **kwargs)
  155. for key, value in items(vars(cls)):
  156. if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
  157. if inspect.ismethod(value) or inspect.isfunction(value):
  158. r.__getattr__(key).side_effect = _bind(value, r)
  159. else:
  160. r.__setattr__(key, value)
  161. return r
  162. def skip_unless_module(module):
  163. def _inner(fun):
  164. @wraps(fun)
  165. def __inner(*args, **kwargs):
  166. try:
  167. importlib.import_module(module)
  168. except ImportError:
  169. raise SkipTest('Does not have %s' % (module,))
  170. return fun(*args, **kwargs)
  171. return __inner
  172. return _inner
  173. # -- adds assertWarns from recent unittest2, not in Python 2.7.
  174. class _AssertRaisesBaseContext(object):
  175. def __init__(self, expected, test_case, callable_obj=None,
  176. expected_regex=None):
  177. self.expected = expected
  178. self.failureException = test_case.failureException
  179. self.obj_name = None
  180. if isinstance(expected_regex, string_t):
  181. expected_regex = re.compile(expected_regex)
  182. self.expected_regex = expected_regex
  183. def _is_magic_module(m):
  184. # some libraries create custom module types that are lazily
  185. # lodaded, e.g. Django installs some modules in sys.modules that
  186. # will load _tkinter and other shit when touched.
  187. # pyflakes refuses to accept 'noqa' for this isinstance.
  188. cls, modtype = type(m), types.ModuleType
  189. try:
  190. variables = vars(cls)
  191. except TypeError:
  192. return True
  193. else:
  194. return (cls is not modtype and (
  195. '__getattr__' in variables or
  196. '__getattribute__' in variables))
  197. class _AssertWarnsContext(_AssertRaisesBaseContext):
  198. """A context manager used to implement TestCase.assertWarns* methods."""
  199. def __enter__(self):
  200. # The __warningregistry__'s need to be in a pristine state for tests
  201. # to work properly.
  202. warnings.resetwarnings()
  203. for v in list(values(sys.modules)):
  204. # do not evaluate Django moved modules and other lazily
  205. # initialized modules.
  206. if v and not _is_magic_module(v):
  207. # use raw __getattribute__ to protect even better from
  208. # lazily loaded modules
  209. try:
  210. object.__getattribute__(v, '__warningregistry__')
  211. except AttributeError:
  212. pass
  213. else:
  214. object.__setattr__(v, '__warningregistry__', {})
  215. self.warnings_manager = warnings.catch_warnings(record=True)
  216. self.warnings = self.warnings_manager.__enter__()
  217. warnings.simplefilter('always', self.expected)
  218. return self
  219. def __exit__(self, exc_type, exc_value, tb):
  220. self.warnings_manager.__exit__(exc_type, exc_value, tb)
  221. if exc_type is not None:
  222. # let unexpected exceptions pass through
  223. return
  224. try:
  225. exc_name = self.expected.__name__
  226. except AttributeError:
  227. exc_name = str(self.expected)
  228. first_matching = None
  229. for m in self.warnings:
  230. w = m.message
  231. if not isinstance(w, self.expected):
  232. continue
  233. if first_matching is None:
  234. first_matching = w
  235. if (self.expected_regex is not None and
  236. not self.expected_regex.search(str(w))):
  237. continue
  238. # store warning for later retrieval
  239. self.warning = w
  240. self.filename = m.filename
  241. self.lineno = m.lineno
  242. return
  243. # Now we simply try to choose a helpful failure message
  244. if first_matching is not None:
  245. raise self.failureException(
  246. '%r does not match %r' % (
  247. self.expected_regex.pattern, str(first_matching)))
  248. if self.obj_name:
  249. raise self.failureException(
  250. '%s not triggered by %s' % (exc_name, self.obj_name))
  251. else:
  252. raise self.failureException('%s not triggered' % exc_name)
  253. def alive_threads():
  254. return [thread for thread in threading.enumerate() if thread.is_alive()]
  255. class Case(unittest.TestCase):
  256. def patch(self, *path, **options):
  257. manager = patch(".".join(path), **options)
  258. patched = manager.start()
  259. self.addCleanup(manager.stop)
  260. return patched
  261. def mock_modules(self, *mods):
  262. modules = []
  263. for mod in mods:
  264. mod = mod.split('.')
  265. modules.extend(reversed([
  266. '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
  267. ]))
  268. modules = sorted(set(modules))
  269. return self.wrap_context(mock_module(*modules))
  270. def on_nth_call_do(self, mock, side_effect, n=1):
  271. def on_call(*args, **kwargs):
  272. if mock.call_count >= n:
  273. mock.side_effect = side_effect
  274. return mock.return_value
  275. mock.side_effect = on_call
  276. return mock
  277. def on_nth_call_return(self, mock, retval, n=1):
  278. def on_call(*args, **kwargs):
  279. if mock.call_count >= n:
  280. mock.return_value = retval
  281. return mock.return_value
  282. mock.side_effect = on_call
  283. return mock
  284. def mask_modules(self, *modules):
  285. self.wrap_context(mask_modules(*modules))
  286. def wrap_context(self, context):
  287. ret = context.__enter__()
  288. self.addCleanup(partial(context.__exit__, None, None, None))
  289. return ret
  290. def mock_environ(self, env_name, env_value):
  291. return self.wrap_context(mock_environ(env_name, env_value))
  292. def assertWarns(self, expected_warning):
  293. return _AssertWarnsContext(expected_warning, self, None)
  294. def assertWarnsRegex(self, expected_warning, expected_regex):
  295. return _AssertWarnsContext(expected_warning, self,
  296. None, expected_regex)
  297. @contextmanager
  298. def assertDeprecated(self):
  299. with self.assertWarnsRegex(CDeprecationWarning,
  300. r'scheduled for removal'):
  301. yield
  302. @contextmanager
  303. def assertPendingDeprecation(self):
  304. with self.assertWarnsRegex(CPendingDeprecationWarning,
  305. r'scheduled for deprecation'):
  306. yield
  307. def assertDictContainsSubset(self, expected, actual, msg=None):
  308. missing, mismatched = [], []
  309. for key, value in items(expected):
  310. if key not in actual:
  311. missing.append(key)
  312. elif value != actual[key]:
  313. mismatched.append('%s, expected: %s, actual: %s' % (
  314. safe_repr(key), safe_repr(value),
  315. safe_repr(actual[key])))
  316. if not (missing or mismatched):
  317. return
  318. standard_msg = ''
  319. if missing:
  320. standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
  321. if mismatched:
  322. if standard_msg:
  323. standard_msg += '; '
  324. standard_msg += 'Mismatched values: %s' % (
  325. ','.join(mismatched))
  326. self.fail(self._formatMessage(msg, standard_msg))
  327. def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  328. missing = unexpected = None
  329. try:
  330. expected = sorted(expected_seq)
  331. actual = sorted(actual_seq)
  332. except TypeError:
  333. # Unsortable items (example: set(), complex(), ...)
  334. expected = list(expected_seq)
  335. actual = list(actual_seq)
  336. missing, unexpected = unorderable_list_difference(
  337. expected, actual)
  338. else:
  339. return self.assertSequenceEqual(expected, actual, msg=msg)
  340. errors = []
  341. if missing:
  342. errors.append(
  343. 'Expected, but missing:\n %s' % (safe_repr(missing),)
  344. )
  345. if unexpected:
  346. errors.append(
  347. 'Unexpected, but present:\n %s' % (safe_repr(unexpected),)
  348. )
  349. if errors:
  350. standardMsg = '\n'.join(errors)
  351. self.fail(self._formatMessage(msg, standardMsg))
  352. def depends_on_current_app(fun):
  353. if inspect.isclass(fun):
  354. fun.contained = False
  355. else:
  356. @wraps(fun)
  357. def __inner(self, *args, **kwargs):
  358. self.app.set_current()
  359. return fun(self, *args, **kwargs)
  360. return __inner
class AppCase(Case):
    """Test case that gives every test its own app as ``self.app``.

    ``setUp`` swaps the global default/current app for a ``Trap`` so that
    code reaching for the global app raises, and ``tearDown`` restores the
    previous global state and checks that the test did not leak threads,
    stream redirection or root logger changes.

    Subclasses must define ``setup``/``teardown`` (lower case) instead of
    ``setUp``/``tearDown``.
    """

    # When False the test app is also made the current app in setUp
    # (see depends_on_current_app).
    contained = True
    # One-element list used as a shared mutable cell: it is mutated in
    # place, so the snapshot is taken once and shared by all instances.
    _threads_at_startup = [None]

    def __init__(self, *args, **kwargs):
        super(AppCase, self).__init__(*args, **kwargs)
        # Only the concrete class' own namespace is inspected, so the
        # setUp/tearDown defined on AppCase itself do not trigger this.
        if self.__class__.__dict__.get('setUp'):
            raise RuntimeError(
                CASE_REDEFINES_SETUP.format(name=qualname(self)),
            )
        if self.__class__.__dict__.get('tearDown'):
            raise RuntimeError(
                CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
            )

    def Celery(self, *args, **kwargs):
        """Create a test app; all arguments are forwarded to ``UnitApp``."""
        return UnitApp(*args, **kwargs)

    def threads_at_startup(self):
        """Return the alive threads recorded on the first call."""
        if self._threads_at_startup[0] is None:
            self._threads_at_startup[0] = alive_threads()
        return self._threads_at_startup[0]

    def setUp(self):
        """Isolate global app state, create ``self.app``, call ``setup``."""
        self._threads_at_setup = self.threads_at_startup()
        from celery import _state
        from celery import result
        # Remember the join-block hooks and replace both with a function
        # that always returns False.
        self._prev_res_join_block = result.task_join_will_block
        self._prev_state_join_block = _state.task_join_will_block
        result.task_join_will_block = \
            _state.task_join_will_block = lambda: False
        # Save the global app state so _teardown_app can restore it.
        self._current_app = current_app()
        self._default_app = _state.default_app
        trap = Trap()
        self._prev_tls = _state._tls
        _state.set_default_app(trap)

        # Replacement for the thread-local state object: a plain object
        # whose current_app is the trap.
        class NonTLS(object):
            current_app = trap
        _state._tls = NonTLS()

        self.app = self.Celery(set_as_current=False)
        if not self.contained:
            self.app.set_current()
        root = logging.getLogger()
        self.__rootlevel = root.level
        # NOTE(review): this stores the handler list itself, not a copy,
        # so handlers added or removed in place cannot be detected later
        # by assert_no_logging_side_effect -- confirm whether intended.
        self.__roothandlers = root.handlers
        _state._set_task_join_will_block(False)
        try:
            self.setup()
        except:
            # Restore global state before re-raising whatever setup()
            # raised.
            self._teardown_app()
            raise

    def _teardown_app(self):
        """Restore the state changed by ``setUp`` and check for leaks."""
        from celery import _state
        from celery import result
        from celery.utils.log import LoggingProxy
        assert sys.stdout
        assert sys.stderr
        assert sys.__stdout__
        assert sys.__stderr__
        this = self._get_test_name()
        result.task_join_will_block = self._prev_res_join_block
        _state.task_join_will_block = self._prev_state_join_block
        # The test must not leave the standard streams replaced by a
        # LoggingProxy or a Mock.
        if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
                isinstance(sys.__stdout__, (LoggingProxy, Mock)):
            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
        if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
                isinstance(sys.__stderr__, (LoggingProxy, Mock)):
            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
        # Read the backend from the instance dict directly; presumably
        # this avoids creating it through a lazy attribute -- TODO confirm.
        backend = self.app.__dict__.get('backend')
        if backend is not None:
            if isinstance(backend, CacheBackend):
                if isinstance(backend.client, DummyClient):
                    backend.client.cache.clear()
                backend._cache.clear()
        from celery import _state
        _state._set_task_join_will_block(False)

        _state.set_default_app(self._default_app)
        _state._tls = self._prev_tls
        _state._tls.current_app = self._current_app
        if self.app is not self._current_app:
            self.app.close()
        self.app = None
        # The set of alive threads must match the startup snapshot.
        self.assertEqual(self._threads_at_setup, alive_threads())

        # Make sure no test left the shutdown flags enabled.
        from celery.worker import state as worker_state
        # check for EX_OK
        self.assertIsNot(worker_state.should_stop, False)
        self.assertIsNot(worker_state.should_terminate, False)
        # check for other true values
        self.assertFalse(worker_state.should_stop)
        self.assertFalse(worker_state.should_terminate)

    def _get_test_name(self):
        """Return ``'<ClassName>.<test method name>'``."""
        return '.'.join([self.__class__.__name__, self._testMethodName])

    def tearDown(self):
        """Run ``teardown`` and always restore the global app state."""
        try:
            self.teardown()
        finally:
            self._teardown_app()
        self.assert_no_logging_side_effect()

    def assert_no_logging_side_effect(self):
        """Raise RuntimeError if the root logger differs from setUp time."""
        this = self._get_test_name()
        root = logging.getLogger()
        if root.level != self.__rootlevel:
            raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
        if root.handlers != self.__roothandlers:
            raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))

    def setup(self):
        """Hook for subclasses; called at the end of ``setUp``."""
        pass

    def teardown(self):
        """Hook for subclasses; called at the start of ``tearDown``."""
        pass
  467. def get_handlers(logger):
  468. return [h for h in logger.handlers if not isinstance(h, NullHandler)]
  469. @contextmanager
  470. def wrap_logger(logger, loglevel=logging.ERROR):
  471. old_handlers = get_handlers(logger)
  472. sio = WhateverIO()
  473. siohandler = logging.StreamHandler(sio)
  474. logger.handlers = [siohandler]
  475. try:
  476. yield sio
  477. finally:
  478. logger.handlers = old_handlers
  479. @contextmanager
  480. def mock_environ(env_name, env_value):
  481. sentinel = object()
  482. prev_val = os.environ.get(env_name, sentinel)
  483. os.environ[env_name] = env_value
  484. try:
  485. yield env_value
  486. finally:
  487. if prev_val is sentinel:
  488. os.environ.pop(env_name, None)
  489. else:
  490. os.environ[env_name] = prev_val
  491. def with_environ(env_name, env_value):
  492. def _envpatched(fun):
  493. @wraps(fun)
  494. def _patch_environ(*args, **kwargs):
  495. with mock_environ(env_name, env_value):
  496. return fun(*args, **kwargs)
  497. return _patch_environ
  498. return _envpatched
  499. def sleepdeprived(module=time):
  500. def _sleepdeprived(fun):
  501. @wraps(fun)
  502. def __sleepdeprived(*args, **kwargs):
  503. old_sleep = module.sleep
  504. module.sleep = noop
  505. try:
  506. return fun(*args, **kwargs)
  507. finally:
  508. module.sleep = old_sleep
  509. return __sleepdeprived
  510. return _sleepdeprived
  511. def skip_if_environ(env_var_name):
  512. def _wrap_test(fun):
  513. @wraps(fun)
  514. def _skips_if_environ(*args, **kwargs):
  515. if os.environ.get(env_var_name):
  516. raise SkipTest('SKIP %s: %s set\n' % (
  517. fun.__name__, env_var_name))
  518. return fun(*args, **kwargs)
  519. return _skips_if_environ
  520. return _wrap_test
  521. def _skip_test(reason, sign):
  522. def _wrap_test(fun):
  523. @wraps(fun)
  524. def _skipped_test(*args, **kwargs):
  525. raise SkipTest('%s: %s' % (sign, reason))
  526. return _skipped_test
  527. return _wrap_test
  528. def todo(reason):
  529. """TODO test decorator."""
  530. return _skip_test(reason, 'TODO')
  531. def skip(reason):
  532. """Skip test decorator."""
  533. return _skip_test(reason, 'SKIP')
  534. def skip_if(predicate, reason):
  535. """Skip test if predicate is :const:`True`."""
  536. def _inner(fun):
  537. return predicate and skip(reason)(fun) or fun
  538. return _inner
  539. def skip_unless(predicate, reason):
  540. """Skip test if predicate is :const:`False`."""
  541. return skip_if(not predicate, reason)
  542. # Taken from
  543. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  544. @contextmanager
  545. def mask_modules(*modnames):
  546. """Ban some modules from being importable inside the context
  547. For example:
  548. >>> with mask_modules('sys'):
  549. ... try:
  550. ... import sys
  551. ... except ImportError:
  552. ... print('sys not found')
  553. sys not found
  554. >>> import sys # noqa
  555. >>> sys.version
  556. (2, 5, 2, 'final', 0)
  557. """
  558. realimport = builtins.__import__
  559. def myimp(name, *args, **kwargs):
  560. if name in modnames:
  561. raise ImportError('No module named %s' % name)
  562. else:
  563. return realimport(name, *args, **kwargs)
  564. builtins.__import__ = myimp
  565. try:
  566. yield True
  567. finally:
  568. builtins.__import__ = realimport
  569. @contextmanager
  570. def override_stdouts():
  571. """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
  572. prev_out, prev_err = sys.stdout, sys.stderr
  573. mystdout, mystderr = WhateverIO(), WhateverIO()
  574. sys.stdout = sys.__stdout__ = mystdout
  575. sys.stderr = sys.__stderr__ = mystderr
  576. try:
  577. yield mystdout, mystderr
  578. finally:
  579. sys.stdout = sys.__stdout__ = prev_out
  580. sys.stderr = sys.__stderr__ = prev_err
  581. def _old_patch(module, name, mocked):
  582. module = importlib.import_module(module)
  583. def _patch(fun):
  584. @wraps(fun)
  585. def __patched(*args, **kwargs):
  586. prev = getattr(module, name)
  587. setattr(module, name, mocked)
  588. try:
  589. return fun(*args, **kwargs)
  590. finally:
  591. setattr(module, name, prev)
  592. return __patched
  593. return _patch
  594. @contextmanager
  595. def replace_module_value(module, name, value=None):
  596. has_prev = hasattr(module, name)
  597. prev = getattr(module, name, None)
  598. if value:
  599. setattr(module, name, value)
  600. else:
  601. try:
  602. delattr(module, name)
  603. except AttributeError:
  604. pass
  605. try:
  606. yield
  607. finally:
  608. if prev is not None:
  609. setattr(module, name, prev)
  610. if not has_prev:
  611. try:
  612. delattr(module, name)
  613. except AttributeError:
  614. pass
# Context manager factory that temporarily sets ``sys.pypy_version_info``
# to the given value, or removes it when called without a value.
pypy_version = partial(
    replace_module_value, sys, 'pypy_version_info',
)
# Context manager factory that temporarily replaces
# ``platform.python_implementation``, or removes it when called without
# a value.
platform_pyimp = partial(
    replace_module_value, platform, 'python_implementation',
)
  621. @contextmanager
  622. def sys_platform(value):
  623. prev, sys.platform = sys.platform, value
  624. try:
  625. yield
  626. finally:
  627. sys.platform = prev
  628. @contextmanager
  629. def reset_modules(*modules):
  630. prev = {k: sys.modules.pop(k) for k in modules if k in sys.modules}
  631. try:
  632. yield
  633. finally:
  634. sys.modules.update(prev)
  635. @contextmanager
  636. def patch_modules(*modules):
  637. prev = {}
  638. for mod in modules:
  639. prev[mod] = sys.modules.get(mod)
  640. sys.modules[mod] = ModuleType(mod)
  641. try:
  642. yield
  643. finally:
  644. for name, mod in items(prev):
  645. if mod is None:
  646. sys.modules.pop(name, None)
  647. else:
  648. sys.modules[name] = mod
  649. @contextmanager
  650. def mock_module(*names):
  651. prev = {}
  652. class MockModule(ModuleType):
  653. def __getattr__(self, attr):
  654. setattr(self, attr, Mock())
  655. return ModuleType.__getattribute__(self, attr)
  656. mods = []
  657. for name in names:
  658. try:
  659. prev[name] = sys.modules[name]
  660. except KeyError:
  661. pass
  662. mod = sys.modules[name] = MockModule(name)
  663. mods.append(mod)
  664. try:
  665. yield mods
  666. finally:
  667. for name in names:
  668. try:
  669. sys.modules[name] = prev[name]
  670. except KeyError:
  671. try:
  672. del(sys.modules[name])
  673. except KeyError:
  674. pass
  675. @contextmanager
  676. def mock_context(mock, typ=Mock):
  677. context = mock.return_value = Mock()
  678. context.__enter__ = typ()
  679. context.__exit__ = typ()
  680. def on_exit(*x):
  681. if x[0]:
  682. reraise(x[0], x[1], x[2])
  683. context.__exit__.side_effect = on_exit
  684. context.__enter__.return_value = context
  685. try:
  686. yield context
  687. finally:
  688. context.reset()
  689. @contextmanager
  690. def mock_open(typ=WhateverIO, side_effect=None):
  691. with patch(open_fqdn) as open_:
  692. with mock_context(open_) as context:
  693. if side_effect is not None:
  694. context.__enter__.side_effect = side_effect
  695. val = context.__enter__.return_value = typ()
  696. val.__exit__ = Mock()
  697. yield val
  698. @contextmanager
  699. def assert_signal_called(signal, **expected):
  700. handler = Mock()
  701. call_handler = partial(handler)
  702. signal.connect(call_handler)
  703. try:
  704. yield handler
  705. finally:
  706. signal.disconnect(call_handler)
  707. handler.assert_called_with(signal=signal, **expected)
  708. def skip_if_pypy(fun):
  709. @wraps(fun)
  710. def _inner(*args, **kwargs):
  711. if getattr(sys, 'pypy_version_info', None):
  712. raise SkipTest('does not work on PyPy')
  713. return fun(*args, **kwargs)
  714. return _inner
  715. def skip_if_jython(fun):
  716. @wraps(fun)
  717. def _inner(*args, **kwargs):
  718. if sys.platform.startswith('java'):
  719. raise SkipTest('does not work on Jython')
  720. return fun(*args, **kwargs)
  721. return _inner
  722. def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
  723. errbacks=None, chain=None, shadow=None, utc=None, **options):
  724. from celery import uuid
  725. from kombu.serialization import dumps
  726. id = id or uuid()
  727. message = Mock(name='TaskMessage-{0}'.format(id))
  728. message.headers = {
  729. 'id': id,
  730. 'task': name,
  731. 'shadow': shadow,
  732. }
  733. embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
  734. message.headers.update(options)
  735. message.content_type, message.content_encoding, message.body = dumps(
  736. (args, kwargs, embed), serializer='json',
  737. )
  738. message.payload = (args, kwargs, embed)
  739. return message
  740. def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
  741. errbacks=None, chain=None, **options):
  742. from celery import uuid
  743. from kombu.serialization import dumps
  744. id = id or uuid()
  745. message = Mock(name='TaskMessage-{0}'.format(id))
  746. message.headers = {}
  747. message.payload = {
  748. 'task': name,
  749. 'id': id,
  750. 'args': args,
  751. 'kwargs': kwargs,
  752. 'callbacks': callbacks,
  753. 'errbacks': errbacks,
  754. }
  755. message.payload.update(options)
  756. message.content_type, message.content_encoding, message.body = dumps(
  757. message.payload,
  758. )
  759. return message
  760. def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
  761. sig.freeze()
  762. callbacks = sig.options.pop('link', None)
  763. errbacks = sig.options.pop('link_error', None)
  764. countdown = sig.options.pop('countdown', None)
  765. if countdown:
  766. eta = app.now() + timedelta(seconds=countdown)
  767. else:
  768. eta = sig.options.pop('eta', None)
  769. if eta and isinstance(eta, datetime):
  770. eta = eta.isoformat()
  771. expires = sig.options.pop('expires', None)
  772. if expires and isinstance(expires, numbers.Real):
  773. expires = app.now() + timedelta(seconds=expires)
  774. if expires and isinstance(expires, datetime):
  775. expires = expires.isoformat()
  776. return TaskMessage(
  777. sig.task, id=sig.id, args=sig.args,
  778. kwargs=sig.kwargs,
  779. callbacks=[dict(s) for s in callbacks] if callbacks else None,
  780. errbacks=[dict(s) for s in errbacks] if errbacks else None,
  781. eta=eta,
  782. expires=expires,
  783. utc=utc,
  784. **sig.options
  785. )
  786. @contextmanager
  787. def restore_logging():
  788. outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
  789. root = logging.getLogger()
  790. level = root.level
  791. handlers = root.handlers
  792. try:
  793. yield
  794. finally:
  795. sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
  796. root.level = level
  797. root.handlers[:] = handlers