conftest.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. from __future__ import absolute_import, unicode_literals
  2. import logging
  3. import numbers
  4. import os
  5. import pytest
  6. import sys
  7. import threading
  8. import warnings
  9. import weakref
  10. from copy import deepcopy
  11. from datetime import datetime, timedelta
  12. from functools import partial
  13. from importlib import import_module
  14. from case import Mock
  15. from case.utils import decorator
  16. from kombu import Queue
  17. from kombu.utils.imports import symbol_by_name
  18. from celery import Celery
  19. from celery.app import current_app
  20. from celery.backends.cache import CacheBackend, DummyClient
  21. try:
  22. WindowsError = WindowsError # noqa
  23. except NameError:
  24. class WindowsError(Exception):
  25. pass
  26. PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] > 3
  27. CASE_LOG_REDIRECT_EFFECT = 'Test {0} didn\'t disable LoggingProxy for {1}'
  28. CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
  29. CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'
  30. CELERY_TEST_CONFIG = {
  31. #: Don't want log output when running suite.
  32. 'worker_hijack_root_logger': False,
  33. 'worker_log_color': False,
  34. 'task_default_queue': 'testcelery',
  35. 'task_default_exchange': 'testcelery',
  36. 'task_default_routing_key': 'testcelery',
  37. 'task_queues': (
  38. Queue('testcelery', routing_key='testcelery'),
  39. ),
  40. 'accept_content': ('json', 'pickle'),
  41. 'enable_utc': True,
  42. 'timezone': 'UTC',
  43. # Mongo results tests (only executed if installed and running)
  44. 'mongodb_backend_settings': {
  45. 'host': os.environ.get('MONGO_HOST') or 'localhost',
  46. 'port': os.environ.get('MONGO_PORT') or 27017,
  47. 'database': os.environ.get('MONGO_DB') or 'celery_unittests',
  48. 'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
  49. 'taskmeta_collection'),
  50. 'user': os.environ.get('MONGO_USER'),
  51. 'password': os.environ.get('MONGO_PASSWORD'),
  52. }
  53. }
  54. class Trap(object):
  55. def __getattr__(self, name):
  56. raise RuntimeError('Test depends on current_app')
  57. class UnitLogging(symbol_by_name(Celery.log_cls)):
  58. def __init__(self, *args, **kwargs):
  59. super(UnitLogging, self).__init__(*args, **kwargs)
  60. self.already_setup = True
  61. def TestApp(name=None, set_as_current=False, log=UnitLogging,
  62. broker='memory://', backend='cache+memory://', **kwargs):
  63. app = Celery(name or 'celery.tests',
  64. set_as_current=set_as_current,
  65. log=log, broker=broker, backend=backend,
  66. **kwargs)
  67. app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
  68. return app
  69. def alive_threads():
  70. return [thread for thread in threading.enumerate() if thread.is_alive()]
  71. @pytest.fixture(autouse=True)
  72. def task_join_will_not_block(request):
  73. from celery import _state
  74. from celery import result
  75. prev_res_join_block = result.task_join_will_block
  76. _state.orig_task_join_will_block = _state.task_join_will_block
  77. prev_state_join_block = _state.task_join_will_block
  78. result.task_join_will_block = \
  79. _state.task_join_will_block = lambda: False
  80. _state._set_task_join_will_block(False)
  81. def fin():
  82. result.task_join_will_block = prev_res_join_block
  83. _state.task_join_will_block = prev_state_join_block
  84. _state._set_task_join_will_block(False)
  85. request.addfinalizer(fin)
  86. @pytest.fixture(scope='session', autouse=True)
  87. def record_threads_at_startup(request):
  88. try:
  89. request.session._threads_at_startup
  90. except AttributeError:
  91. request.session._threads_at_startup = alive_threads()
  92. @pytest.fixture(autouse=True)
  93. def threads_not_lingering(request):
  94. def fin():
  95. assert request.session._threads_at_startup == alive_threads()
  96. request.addfinalizer(fin)
@pytest.fixture(autouse=True)
def app(request):
    """Provide a fresh ``TestApp`` for every test and isolate global state.

    While the test runs, the default app and the "current app" slot in
    ``celery._state`` are replaced by a ``Trap`` instance.  The new app is
    made current only when the test module, class or function defines
    ``app_contained`` with a falsy value.  A finalizer restores the previous
    default app, ``_tls`` object, current app, app finalizers and app
    registry, and closes the test app.

    :returns: the app created for this test.
    """
    from celery import _state
    # Snapshot the global state so ``fin`` can put it back afterwards.
    prev_current_app = current_app()
    prev_default_app = _state.default_app
    prev_finalizers = set(_state._on_app_finalizers)
    prev_apps = weakref.WeakSet(_state._apps)
    trap = Trap()
    prev_tls = _state._tls
    _state.set_default_app(trap)

    # Replacement for ``_state._tls`` whose ``current_app`` is the trap.
    class NonTLS(object):
        current_app = trap
    _state._tls = NonTLS()

    app = TestApp(set_as_current=False)
    # A missing ``app_contained`` attribute (or a ``None`` cls) counts as
    # contained, because the getattr default is True.
    is_not_contained = any([
        not getattr(request.module, 'app_contained', True),
        not getattr(request.cls, 'app_contained', True),
        not getattr(request.function, 'app_contained', True)
    ])
    if is_not_contained:
        app.set_current()

    def fin():
        # Restore in this order: default app, the TLS object, then the
        # current app stored on that restored TLS object.
        _state.set_default_app(prev_default_app)
        _state._tls = prev_tls
        _state._tls.current_app = prev_current_app
        if app is not prev_current_app:
            app.close()
        _state._on_app_finalizers = prev_finalizers
        _state._apps = prev_apps
    request.addfinalizer(fin)
    return app
@pytest.fixture()
def depends_on_current_app(app):
    """Opt-in fixture that makes the test's ``app`` the current app."""
    app.set_current()
  131. @pytest.fixture(autouse=True)
  132. def test_cases_shortcuts(request, app, patching):
  133. if request.instance:
  134. @app.task
  135. def add(x, y):
  136. return x + y
  137. # IMPORTANT: We set an .app attribute for every test case class.
  138. request.instance.app = app
  139. request.instance.Celery = TestApp
  140. request.instance.assert_signal_called = assert_signal_called
  141. request.instance.task_message_from_sig = task_message_from_sig
  142. request.instance.TaskMessage = TaskMessage
  143. request.instance.TaskMessage1 = TaskMessage1
  144. request.instance.CELERY_TEST_CONFIG = dict(CELERY_TEST_CONFIG)
  145. request.instance.add = add
  146. request.instance.patching = patching
  147. def fin():
  148. request.instance.app = None
  149. request.addfinalizer(fin)
  150. @pytest.fixture(autouse=True)
  151. def zzzz_test_cases_calls_setup_teardown(request):
  152. if request.instance:
  153. # we set the .patching attribute for every test class.
  154. setup = getattr(request.instance, 'setup', None)
  155. # we also call .setup() and .teardown() after every test method.
  156. teardown = getattr(request.instance, 'teardown', None)
  157. setup and setup()
  158. teardown and request.addfinalizer(teardown)
  159. @pytest.fixture(autouse=True)
  160. def sanity_no_shutdown_flags_set(request):
  161. def fin():
  162. # Make sure no test left the shutdown flags enabled.
  163. from celery.worker import state as worker_state
  164. # check for EX_OK
  165. assert worker_state.should_stop is not False
  166. assert worker_state.should_terminate is not False
  167. # check for other true values
  168. assert not worker_state.should_stop
  169. assert not worker_state.should_terminate
  170. request.addfinalizer(fin)
  171. @pytest.fixture(autouse=True)
  172. def reset_cache_backend_state(request, app):
  173. def fin():
  174. backend = app.__dict__.get('backend')
  175. if backend is not None:
  176. if isinstance(backend, CacheBackend):
  177. if isinstance(backend.client, DummyClient):
  178. backend.client.cache.clear()
  179. backend._cache.clear()
  180. request.addfinalizer(fin)
  181. @pytest.fixture(autouse=True)
  182. def sanity_stdouts(request):
  183. def fin():
  184. from celery.utils.log import LoggingProxy
  185. assert sys.stdout
  186. assert sys.stderr
  187. assert sys.__stdout__
  188. assert sys.__stderr__
  189. this = request.node.name
  190. if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
  191. isinstance(sys.__stdout__, (LoggingProxy, Mock)):
  192. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
  193. if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
  194. isinstance(sys.__stderr__, (LoggingProxy, Mock)):
  195. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
  196. request.addfinalizer(fin)
  197. @pytest.fixture(autouse=True)
  198. def sanity_logging_side_effects(request):
  199. root = logging.getLogger()
  200. rootlevel = root.level
  201. roothandlers = root.handlers
  202. def fin():
  203. this = request.node.name
  204. root_now = logging.getLogger()
  205. if root_now.level != rootlevel:
  206. raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
  207. if root_now.handlers != roothandlers:
  208. raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
  209. request.addfinalizer(fin)
  210. def setup_session(scope='session'):
  211. using_coverage = (
  212. os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
  213. )
  214. os.environ.update(
  215. # warn if config module not found
  216. C_WNOCONF='yes',
  217. KOMBU_DISABLE_LIMIT_PROTECTION='yes',
  218. )
  219. if using_coverage and not PYPY3:
  220. from warnings import catch_warnings
  221. with catch_warnings(record=True):
  222. import_all_modules()
  223. warnings.resetwarnings()
  224. from celery._state import set_default_app
  225. set_default_app(Trap())
  226. def teardown():
  227. # Don't want SUBDEBUG log messages at finalization.
  228. try:
  229. from multiprocessing.util import get_logger
  230. except ImportError:
  231. pass
  232. else:
  233. get_logger().setLevel(logging.WARNING)
  234. # Make sure test database is removed.
  235. import os
  236. if os.path.exists('test.db'):
  237. try:
  238. os.remove('test.db')
  239. except WindowsError:
  240. pass
  241. # Make sure there are no remaining threads at shutdown.
  242. import threading
  243. remaining_threads = [thread for thread in threading.enumerate()
  244. if thread.getName() != 'MainThread']
  245. if remaining_threads:
  246. sys.stderr.write(
  247. '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
  248. remaining_threads))
  249. def find_distribution_modules(name=__name__, file=__file__):
  250. current_dist_depth = len(name.split('.')) - 1
  251. current_dist = os.path.join(os.path.dirname(file),
  252. *([os.pardir] * current_dist_depth))
  253. abs = os.path.abspath(current_dist)
  254. dist_name = os.path.basename(abs)
  255. for dirpath, dirnames, filenames in os.walk(abs):
  256. package = (dist_name + dirpath[len(abs):]).replace('/', '.')
  257. if '__init__.py' in filenames:
  258. yield package
  259. for filename in filenames:
  260. if filename.endswith('.py') and filename != '__init__.py':
  261. yield '.'.join([package, filename])[:-3]
  262. def import_all_modules(name=__name__, file=__file__,
  263. skip=('celery.decorators',
  264. 'celery.task')):
  265. for module in find_distribution_modules(name, file):
  266. if not module.startswith(skip):
  267. try:
  268. import_module(module)
  269. except ImportError:
  270. pass
  271. except OSError as exc:
  272. warnings.warn(UserWarning(
  273. 'Ignored error importing module {0}: {1!r}'.format(
  274. module, exc,
  275. )))
  276. @decorator
  277. def assert_signal_called(signal, **expected):
  278. handler = Mock()
  279. call_handler = partial(handler)
  280. signal.connect(call_handler)
  281. try:
  282. yield handler
  283. finally:
  284. signal.disconnect(call_handler)
  285. handler.assert_called_with(signal=signal, **expected)
  286. def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
  287. errbacks=None, chain=None, shadow=None, utc=None, **options):
  288. from celery import uuid
  289. from kombu.serialization import dumps
  290. id = id or uuid()
  291. message = Mock(name='TaskMessage-{0}'.format(id))
  292. message.headers = {
  293. 'id': id,
  294. 'task': name,
  295. 'shadow': shadow,
  296. }
  297. embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
  298. message.headers.update(options)
  299. message.content_type, message.content_encoding, message.body = dumps(
  300. (args, kwargs, embed), serializer='json',
  301. )
  302. message.payload = (args, kwargs, embed)
  303. return message
  304. def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
  305. errbacks=None, chain=None, **options):
  306. from celery import uuid
  307. from kombu.serialization import dumps
  308. id = id or uuid()
  309. message = Mock(name='TaskMessage-{0}'.format(id))
  310. message.headers = {}
  311. message.payload = {
  312. 'task': name,
  313. 'id': id,
  314. 'args': args,
  315. 'kwargs': kwargs,
  316. 'callbacks': callbacks,
  317. 'errbacks': errbacks,
  318. }
  319. message.payload.update(options)
  320. message.content_type, message.content_encoding, message.body = dumps(
  321. message.payload,
  322. )
  323. return message
  324. def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
  325. sig.freeze()
  326. callbacks = sig.options.pop('link', None)
  327. errbacks = sig.options.pop('link_error', None)
  328. countdown = sig.options.pop('countdown', None)
  329. if countdown:
  330. eta = app.now() + timedelta(seconds=countdown)
  331. else:
  332. eta = sig.options.pop('eta', None)
  333. if eta and isinstance(eta, datetime):
  334. eta = eta.isoformat()
  335. expires = sig.options.pop('expires', None)
  336. if expires and isinstance(expires, numbers.Real):
  337. expires = app.now() + timedelta(seconds=expires)
  338. if expires and isinstance(expires, datetime):
  339. expires = expires.isoformat()
  340. return TaskMessage(
  341. sig.task, id=sig.id, args=sig.args,
  342. kwargs=sig.kwargs,
  343. callbacks=[dict(s) for s in callbacks] if callbacks else None,
  344. errbacks=[dict(s) for s in errbacks] if errbacks else None,
  345. eta=eta,
  346. expires=expires,
  347. utc=utc,
  348. **sig.options
  349. )