conftest.py

from __future__ import absolute_import, unicode_literals

import logging
import numbers
import os
import pytest
import sys
import threading
import warnings
import weakref

from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
from importlib import import_module

from case import Mock
from case.utils import decorator
from kombu import Queue
from kombu.utils.imports import symbol_by_name

from celery import Celery
from celery.app import current_app
from celery.backends.cache import CacheBackend, DummyClient

try:
    WindowsError = WindowsError  # noqa
except NameError:

    class WindowsError(Exception):
        pass

PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] >= 3

CASE_LOG_REDIRECT_EFFECT = 'Test {0} didn\'t disable LoggingProxy for {1}'
CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'

CELERY_TEST_CONFIG = {
    #: Don't want log output when running suite.
    'worker_hijack_root_logger': False,
    'worker_log_color': False,
    'task_default_queue': 'testcelery',
    'task_default_exchange': 'testcelery',
    'task_default_routing_key': 'testcelery',
    'task_queues': (
        Queue('testcelery', routing_key='testcelery'),
    ),
    'accept_content': ('json', 'pickle'),
    'enable_utc': True,
    'timezone': 'UTC',
    # Mongo results tests (only executed if installed and running)
    'mongodb_backend_settings': {
        'host': os.environ.get('MONGO_HOST') or 'localhost',
        'port': os.environ.get('MONGO_PORT') or 27017,
        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
                                'taskmeta_collection'),
        'user': os.environ.get('MONGO_USER'),
        'password': os.environ.get('MONGO_PASSWORD'),
    }
}
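

# Hedged note (not part of the original fixture machinery): the Mongo settings
# above are only exercised by the optional MongoDB result-backend tests and
# fall back to localhost defaults.  To point them at another server you could
# export, for example, ``MONGO_HOST=mongo.example.com MONGO_DB=celery_ci``
# before running the suite (illustrative values; the variable names come from
# the dict above).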


@pytest.fixture(autouse=True, scope='session')
def AAA_disable_multiprocessing(request):
    # pytest-cov breaks if a multiprocessing.Process is started,
    # so disable them completely to make sure it doesn't happen.
    from case import patch
    stuff = [
        'multiprocessing.Process',
        'billiard.Process',
        'billiard.context.Process',
        'billiard.process.Process',
        'billiard.process.BaseProcess',
        'multiprocessing.Process',
    ]
    if sys.version_info[0] >= 3:
        stuff.append('multiprocessing.process.BaseProcess')
    else:
        stuff.append('multiprocessing.process.Process')
    ctxs = [patch(s) for s in stuff]
    [ctx.__enter__() for ctx in ctxs]

    def fin():
        [ctx.__exit__(*sys.exc_info()) for ctx in ctxs]
    request.addfinalizer(fin)


class Trap(object):

    def __getattr__(self, name):
        raise RuntimeError('Test depends on current_app')


class UnitLogging(symbol_by_name(Celery.log_cls)):

    def __init__(self, *args, **kwargs):
        super(UnitLogging, self).__init__(*args, **kwargs)
        self.already_setup = True


def TestApp(name=None, set_as_current=False, log=UnitLogging,
            broker='memory://', backend='cache+memory://', **kwargs):
    app = Celery(name or 'celery.tests',
                 set_as_current=set_as_current,
                 log=log, broker=broker, backend=backend,
                 **kwargs)
    app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
    return app
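

# A minimal usage sketch, not part of the fixtures themselves: tests that need
# a throwaway app configured like the suite default can call TestApp()
# directly (it is also exposed to test instances as ``self.Celery`` below),
# for example::
#
#     def test_uses_private_app():
#         app = TestApp('t.unit.example')
#         assert app.conf.task_default_queue == 'testcelery'
#
# The test name and app name here are illustrative only.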


def alive_threads():
    return [thread for thread in threading.enumerate() if thread.is_alive()]


@pytest.fixture(autouse=True)
def task_join_will_not_block(request):
    from celery import _state
    from celery import result
    prev_res_join_block = result.task_join_will_block
    _state.orig_task_join_will_block = _state.task_join_will_block
    prev_state_join_block = _state.task_join_will_block
    result.task_join_will_block = \
        _state.task_join_will_block = lambda: False
    _state._set_task_join_will_block(False)

    def fin():
        result.task_join_will_block = prev_res_join_block
        _state.task_join_will_block = prev_state_join_block
        _state._set_task_join_will_block(False)
    request.addfinalizer(fin)


@pytest.fixture(scope='session', autouse=True)
def record_threads_at_startup(request):
    try:
        request.session._threads_at_startup
    except AttributeError:
        request.session._threads_at_startup = alive_threads()


@pytest.fixture(autouse=True)
def threads_not_lingering(request):
    def fin():
        assert request.session._threads_at_startup == alive_threads()
    request.addfinalizer(fin)


@pytest.fixture(autouse=True)
def app(request):
    from celery import _state
    prev_current_app = current_app()
    prev_default_app = _state.default_app
    prev_finalizers = set(_state._on_app_finalizers)
    prev_apps = weakref.WeakSet(_state._apps)
    trap = Trap()
    prev_tls = _state._tls
    _state.set_default_app(trap)

    class NonTLS(object):
        current_app = trap
    _state._tls = NonTLS()

    app = TestApp(set_as_current=False)
    is_not_contained = any([
        not getattr(request.module, 'app_contained', True),
        not getattr(request.cls, 'app_contained', True),
        not getattr(request.function, 'app_contained', True)
    ])
    if is_not_contained:
        app.set_current()

    def fin():
        _state.set_default_app(prev_default_app)
        _state._tls = prev_tls
        _state._tls.current_app = prev_current_app
        if app is not prev_current_app:
            app.close()
        _state._on_app_finalizers = prev_finalizers
        _state._apps = prev_apps
    request.addfinalizer(fin)
    return app


@pytest.fixture()
def depends_on_current_app(app):
    app.set_current()
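

# Hedged usage note: the ``app`` fixture above installs a Trap as the default
# app, so accidental use of the global current app fails loudly.  A test that
# legitimately needs the current app can opt in, for example::
#
#     @pytest.mark.usefixtures('depends_on_current_app')
#     def test_needs_current_app(app):
#         ...
#
# or set ``app_contained = False`` on the module, class, or function, which
# the ``app`` fixture checks before calling ``app.set_current()``.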


@pytest.fixture(autouse=True)
def test_cases_shortcuts(request, app, patching):
    if request.instance:
        @app.task
        def add(x, y):
            return x + y

        # IMPORTANT: We set an .app attribute for every test case class.
        request.instance.app = app
        request.instance.Celery = TestApp
        request.instance.assert_signal_called = assert_signal_called
        request.instance.task_message_from_sig = task_message_from_sig
        request.instance.TaskMessage = TaskMessage
        request.instance.TaskMessage1 = TaskMessage1
        request.instance.CELERY_TEST_CONFIG = dict(CELERY_TEST_CONFIG)
        request.instance.add = add
        request.instance.patching = patching

        def fin():
            request.instance.app = None
        request.addfinalizer(fin)


@pytest.fixture(autouse=True)
def zzzz_test_cases_calls_setup_teardown(request):
    if request.instance:
        # Call .setup() before, and .teardown() after, every test method
        # on test case classes that define them.
        setup = getattr(request.instance, 'setup', None)
        teardown = getattr(request.instance, 'teardown', None)
        setup and setup()
        teardown and request.addfinalizer(teardown)


@pytest.fixture(autouse=True)
def sanity_no_shutdown_flags_set(request):
    def fin():
        # Make sure no test left the shutdown flags enabled.
        from celery.worker import state as worker_state
        # check for EX_OK
        assert worker_state.should_stop is not False
        assert worker_state.should_terminate is not False
        # check for other true values
        assert not worker_state.should_stop
        assert not worker_state.should_terminate
    request.addfinalizer(fin)


@pytest.fixture(autouse=True)
def reset_cache_backend_state(request, app):
    def fin():
        backend = app.__dict__.get('backend')
        if backend is not None:
            if isinstance(backend, CacheBackend):
                if isinstance(backend.client, DummyClient):
                    backend.client.cache.clear()
                backend._cache.clear()
    request.addfinalizer(fin)


@pytest.fixture(autouse=True)
def sanity_stdouts(request):
    def fin():
        from celery.utils.log import LoggingProxy
        assert sys.stdout
        assert sys.stderr
        assert sys.__stdout__
        assert sys.__stderr__
        this = request.node.name
        if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
                isinstance(sys.__stdout__, (LoggingProxy, Mock)):
            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
        if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
                isinstance(sys.__stderr__, (LoggingProxy, Mock)):
            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
    request.addfinalizer(fin)


@pytest.fixture(autouse=True)
def sanity_logging_side_effects(request):
    root = logging.getLogger()
    rootlevel = root.level
    roothandlers = root.handlers

    def fin():
        this = request.node.name
        root_now = logging.getLogger()
        if root_now.level != rootlevel:
            raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
        if root_now.handlers != roothandlers:
            raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
    request.addfinalizer(fin)


def setup_session(scope='session'):
    using_coverage = (
        os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
    )
    os.environ.update(
        # warn if config module not found
        C_WNOCONF='yes',
        KOMBU_DISABLE_LIMIT_PROTECTION='yes',
    )

    if using_coverage and not PYPY3:
        from warnings import catch_warnings
        with catch_warnings(record=True):
            import_all_modules()
        warnings.resetwarnings()
    from celery._state import set_default_app
    set_default_app(Trap())


def teardown():
    # Don't want SUBDEBUG log messages at finalization.
    try:
        from multiprocessing.util import get_logger
    except ImportError:
        pass
    else:
        get_logger().setLevel(logging.WARNING)

    # Make sure test database is removed.
    import os
    if os.path.exists('test.db'):
        try:
            os.remove('test.db')
        except WindowsError:
            pass

    # Make sure there are no remaining threads at shutdown.
    import threading
    remaining_threads = [thread for thread in threading.enumerate()
                         if thread.getName() != 'MainThread']
    if remaining_threads:
        sys.stderr.write(
            '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
                remaining_threads))


def find_distribution_modules(name=__name__, file=__file__):
    current_dist_depth = len(name.split('.')) - 1
    current_dist = os.path.join(os.path.dirname(file),
                                *([os.pardir] * current_dist_depth))
    abs = os.path.abspath(current_dist)
    dist_name = os.path.basename(abs)

    for dirpath, dirnames, filenames in os.walk(abs):
        package = (dist_name + dirpath[len(abs):]).replace('/', '.')
        if '__init__.py' in filenames:
            yield package
            for filename in filenames:
                if filename.endswith('.py') and filename != '__init__.py':
                    yield '.'.join([package, filename])[:-3]


def import_all_modules(name=__name__, file=__file__,
                       skip=('celery.decorators',
                             'celery.task')):
    for module in find_distribution_modules(name, file):
        if not module.startswith(skip):
            try:
                import_module(module)
            except ImportError:
                pass
            except OSError as exc:
                warnings.warn(UserWarning(
                    'Ignored error importing module {0}: {1!r}'.format(
                        module, exc,
                    )))


@decorator
def assert_signal_called(signal, **expected):
    handler = Mock()
    call_handler = partial(handler)
    signal.connect(call_handler)
    try:
        yield handler
    finally:
        signal.disconnect(call_handler)
    handler.assert_called_with(signal=signal, **expected)
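

# Hedged usage sketch: ``@decorator`` (from case.utils) turns the generator
# above into a context manager, so a test can assert that a signal fired with
# the expected keyword arguments while a block of code ran, e.g.::
#
#     with assert_signal_called(some_signal, sender=sender, value=1):
#         code_that_should_send_the_signal()
#
# ``some_signal``, ``sender`` and ``value`` are placeholders; the final
# assertion is ``handler.assert_called_with(signal=signal, **expected)``.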


def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
                errbacks=None, chain=None, shadow=None, utc=None, **options):
    from celery import uuid
    from kombu.serialization import dumps
    id = id or uuid()
    message = Mock(name='TaskMessage-{0}'.format(id))
    message.headers = {
        'id': id,
        'task': name,
        'shadow': shadow,
    }
    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
    message.headers.update(options)
    message.content_type, message.content_encoding, message.body = dumps(
        (args, kwargs, embed), serializer='json',
    )
    message.payload = (args, kwargs, embed)
    return message
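

# Hedged example: TaskMessage builds a Mock standing in for a protocol-2 task
# message (task metadata in the headers, ``(args, kwargs, embed)`` as the
# body), so a worker test might construct one like::
#
#     msg = TaskMessage('tasks.add', args=(2, 2), kwargs={})
#     assert msg.headers['task'] == 'tasks.add'
#     assert msg.payload[0] == (2, 2)
#
# ``tasks.add`` is an illustrative task name only.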


def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
                 errbacks=None, chain=None, **options):
    from celery import uuid
    from kombu.serialization import dumps
    id = id or uuid()
    message = Mock(name='TaskMessage-{0}'.format(id))
    message.headers = {}
    message.payload = {
        'task': name,
        'id': id,
        'args': args,
        'kwargs': kwargs,
        'callbacks': callbacks,
        'errbacks': errbacks,
    }
    message.payload.update(options)
    message.content_type, message.content_encoding, message.body = dumps(
        message.payload,
    )
    return message


def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
    sig.freeze()
    callbacks = sig.options.pop('link', None)
    errbacks = sig.options.pop('link_error', None)
    countdown = sig.options.pop('countdown', None)
    if countdown:
        eta = app.now() + timedelta(seconds=countdown)
    else:
        eta = sig.options.pop('eta', None)
    if eta and isinstance(eta, datetime):
        eta = eta.isoformat()
    expires = sig.options.pop('expires', None)
    if expires and isinstance(expires, numbers.Real):
        expires = app.now() + timedelta(seconds=expires)
    if expires and isinstance(expires, datetime):
        expires = expires.isoformat()
    return TaskMessage(
        sig.task, id=sig.id, args=sig.args,
        kwargs=sig.kwargs,
        callbacks=[dict(s) for s in callbacks] if callbacks else None,
        errbacks=[dict(s) for s in errbacks] if errbacks else None,
        eta=eta,
        expires=expires,
        utc=utc,
        **sig.options
    )
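

# Hedged example of how these helpers compose (the test body is illustrative
# only): given the ``app`` fixture and a task signature, a test can turn the
# signature into a protocol-2 message (or a protocol-1 message by passing
# ``TaskMessage=TaskMessage1``)::
#
#     def test_sig_to_message(app):
#         @app.task
#         def add(x, y):
#             return x + y
#
#         msg = task_message_from_sig(app, add.s(2, 2))
#         assert msg.headers['task'] == add.name
#
# Note that ``countdown``/``eta``/``expires`` options on the signature are
# converted to ISO-8601 strings, as shown in the function above.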