conftest.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. from __future__ import absolute_import, unicode_literals
  2. import logging
  3. import os
  4. import pytest
  5. import sys
  6. import threading
  7. import warnings
  8. from functools import partial
  9. from importlib import import_module
  10. from case import Mock
  11. from case.utils import decorator
  12. from kombu import Queue
  13. from celery.backends.cache import CacheBackend, DummyClient
  14. from celery.contrib.testing.app import Trap, TestApp
  15. from celery.contrib.testing.mocks import (
  16. TaskMessage, TaskMessage1, task_message_from_sig,
  17. )
  18. try:
  19. WindowsError = WindowsError # noqa
  20. except NameError:
  21. class WindowsError(Exception):
  22. pass
  23. PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] > 3
  24. CASE_LOG_REDIRECT_EFFECT = 'Test {0} didn\'t disable LoggingProxy for {1}'
  25. CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
  26. CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'
  27. @pytest.fixture(scope='session')
  28. def celery_config():
  29. return {
  30. 'broker_url': 'memory://',
  31. 'result_backend': 'cache+memory://',
  32. 'task_default_queue': 'testcelery',
  33. 'task_default_exchange': 'testcelery',
  34. 'task_default_routing_key': 'testcelery',
  35. 'task_queues': (
  36. Queue('testcelery', routing_key='testcelery'),
  37. ),
  38. 'accept_content': ('json', 'pickle'),
  39. # Mongo results tests (only executed if installed and running)
  40. 'mongodb_backend_settings': {
  41. 'host': os.environ.get('MONGO_HOST') or 'localhost',
  42. 'port': os.environ.get('MONGO_PORT') or 27017,
  43. 'database': os.environ.get('MONGO_DB') or 'celery_unittests',
  44. 'taskmeta_collection': (
  45. os.environ.get('MONGO_TASKMETA_COLLECTION') or
  46. 'taskmeta_collection'
  47. ),
  48. 'user': os.environ.get('MONGO_USER'),
  49. 'password': os.environ.get('MONGO_PASSWORD'),
  50. }
  51. }
  52. @pytest.fixture(scope='session')
  53. def use_celery_app_trap():
  54. return True
  55. @pytest.fixture(autouse=True)
  56. def reset_cache_backend_state(celery_app):
  57. """Fixture that resets the internal state of the cache result backend."""
  58. yield
  59. backend = celery_app.__dict__.get('backend')
  60. if backend is not None:
  61. if isinstance(backend, CacheBackend):
  62. if isinstance(backend.client, DummyClient):
  63. backend.client.cache.clear()
  64. backend._cache.clear()
  65. @decorator
  66. def assert_signal_called(signal, **expected):
  67. """Context that verifes signal is called before exiting."""
  68. handler = Mock()
  69. call_handler = partial(handler)
  70. signal.connect(call_handler)
  71. try:
  72. yield handler
  73. finally:
  74. signal.disconnect(call_handler)
  75. handler.assert_called_with(signal=signal, **expected)
  76. @pytest.fixture
  77. def app(celery_app):
  78. yield celery_app
  79. @pytest.fixture(autouse=True, scope='session')
  80. def AAA_disable_multiprocessing():
  81. # pytest-cov breaks if a multiprocessing.Process is started,
  82. # so disable them completely to make sure it doesn't happen.
  83. from case import patch
  84. stuff = [
  85. 'multiprocessing.Process',
  86. 'billiard.Process',
  87. 'billiard.context.Process',
  88. 'billiard.process.Process',
  89. 'billiard.process.BaseProcess',
  90. 'multiprocessing.Process',
  91. ]
  92. ctxs = [patch(s) for s in stuff]
  93. [ctx.__enter__() for ctx in ctxs]
  94. yield
  95. [ctx.__exit__(*sys.exc_info()) for ctx in ctxs]
  96. def alive_threads():
  97. return [thread for thread in threading.enumerate() if thread.is_alive()]
  98. @pytest.fixture(autouse=True)
  99. def task_join_will_not_block():
  100. from celery import _state
  101. from celery import result
  102. prev_res_join_block = result.task_join_will_block
  103. _state.orig_task_join_will_block = _state.task_join_will_block
  104. prev_state_join_block = _state.task_join_will_block
  105. result.task_join_will_block = \
  106. _state.task_join_will_block = lambda: False
  107. _state._set_task_join_will_block(False)
  108. yield
  109. result.task_join_will_block = prev_res_join_block
  110. _state.task_join_will_block = prev_state_join_block
  111. _state._set_task_join_will_block(False)
  112. @pytest.fixture(scope='session', autouse=True)
  113. def record_threads_at_startup(request):
  114. try:
  115. request.session._threads_at_startup
  116. except AttributeError:
  117. request.session._threads_at_startup = alive_threads()
  118. @pytest.fixture(autouse=True)
  119. def threads_not_lingering(request):
  120. yield
  121. assert request.session._threads_at_startup == alive_threads()
  122. @pytest.fixture(autouse=True)
  123. def AAA_reset_CELERY_LOADER_env():
  124. yield
  125. assert not os.environ.get('CELERY_LOADER')
  126. @pytest.fixture(autouse=True)
  127. def test_cases_shortcuts(request, app, patching, celery_config):
  128. if request.instance:
  129. @app.task
  130. def add(x, y):
  131. return x + y
  132. # IMPORTANT: We set an .app attribute for every test case class.
  133. request.instance.app = app
  134. request.instance.Celery = TestApp
  135. request.instance.assert_signal_called = assert_signal_called
  136. request.instance.task_message_from_sig = task_message_from_sig
  137. request.instance.TaskMessage = TaskMessage
  138. request.instance.TaskMessage1 = TaskMessage1
  139. request.instance.CELERY_TEST_CONFIG = celery_config
  140. request.instance.add = add
  141. request.instance.patching = patching
  142. yield
  143. if request.instance:
  144. request.instance.app = None
  145. @pytest.fixture(autouse=True)
  146. def sanity_no_shutdown_flags_set():
  147. yield
  148. # Make sure no test left the shutdown flags enabled.
  149. from celery.worker import state as worker_state
  150. # check for EX_OK
  151. assert worker_state.should_stop is not False
  152. assert worker_state.should_terminate is not False
  153. # check for other true values
  154. assert not worker_state.should_stop
  155. assert not worker_state.should_terminate
  156. @pytest.fixture(autouse=True)
  157. def sanity_stdouts(request):
  158. yield
  159. from celery.utils.log import LoggingProxy
  160. assert sys.stdout
  161. assert sys.stderr
  162. assert sys.__stdout__
  163. assert sys.__stderr__
  164. this = request.node.name
  165. if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
  166. isinstance(sys.__stdout__, (LoggingProxy, Mock)):
  167. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
  168. if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
  169. isinstance(sys.__stderr__, (LoggingProxy, Mock)):
  170. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
  171. @pytest.fixture(autouse=True)
  172. def sanity_logging_side_effects(request):
  173. root = logging.getLogger()
  174. rootlevel = root.level
  175. roothandlers = root.handlers
  176. yield
  177. this = request.node.name
  178. root_now = logging.getLogger()
  179. if root_now.level != rootlevel:
  180. raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
  181. if root_now.handlers != roothandlers:
  182. raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
def setup_session(scope='session'):
    """Prepare process-wide state for a test session.

    Sets ``C_WNOCONF`` and ``KOMBU_DISABLE_LIMIT_PROTECTION`` to ``'yes'``
    in ``os.environ``.  When coverage is requested (``COVER_ALL_MODULES``
    env var, or ``--with-coverage`` in ``sys.argv``) and ``PYPY3`` is
    falsy, imports every module of the distribution while recording
    warnings.  Finally installs a ``Trap()`` as the default app.

    :param scope: not used by the body.  NOTE(review): no
        ``@pytest.fixture`` decorator is applied here, so pytest will not
        call this automatically — confirm whether it is invoked elsewhere
        or is dead code.
    """
    using_coverage = (
        os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
    )
    os.environ.update(
        # warn if config module not found
        C_WNOCONF='yes',
        KOMBU_DISABLE_LIMIT_PROTECTION='yes',
    )
    if using_coverage and not PYPY3:
        from warnings import catch_warnings
        # Warnings raised while importing are recorded into a list that is
        # never bound to a name, i.e. discarded.
        with catch_warnings(record=True):
            import_all_modules()
        # NOTE(review): catch_warnings already restores the filters on exit;
        # resetwarnings() then clears *all* warning filters, including any
        # installed by pytest or -W options — confirm this is intended.
        warnings.resetwarnings()
    from celery._state import set_default_app
    set_default_app(Trap())
  199. def teardown():
  200. # Don't want SUBDEBUG log messages at finalization.
  201. try:
  202. from multiprocessing.util import get_logger
  203. except ImportError:
  204. pass
  205. else:
  206. get_logger().setLevel(logging.WARNING)
  207. # Make sure test database is removed.
  208. import os
  209. if os.path.exists('test.db'):
  210. try:
  211. os.remove('test.db')
  212. except WindowsError:
  213. pass
  214. # Make sure there are no remaining threads at shutdown.
  215. import threading
  216. remaining_threads = [thread for thread in threading.enumerate()
  217. if thread.getName() != 'MainThread']
  218. if remaining_threads:
  219. sys.stderr.write(
  220. '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
  221. remaining_threads))
  222. def find_distribution_modules(name=__name__, file=__file__):
  223. current_dist_depth = len(name.split('.')) - 1
  224. current_dist = os.path.join(os.path.dirname(file),
  225. *([os.pardir] * current_dist_depth))
  226. abs = os.path.abspath(current_dist)
  227. dist_name = os.path.basename(abs)
  228. for dirpath, dirnames, filenames in os.walk(abs):
  229. package = (dist_name + dirpath[len(abs):]).replace('/', '.')
  230. if '__init__.py' in filenames:
  231. yield package
  232. for filename in filenames:
  233. if filename.endswith('.py') and filename != '__init__.py':
  234. yield '.'.join([package, filename])[:-3]
  235. def import_all_modules(name=__name__, file=__file__,
  236. skip=('celery.decorators',
  237. 'celery.task')):
  238. for module in find_distribution_modules(name, file):
  239. if not module.startswith(skip):
  240. try:
  241. import_module(module)
  242. except ImportError:
  243. pass
  244. except OSError as exc:
  245. warnings.warn(UserWarning(
  246. 'Ignored error importing module {0}: {1!r}'.format(
  247. module, exc,
  248. )))