case.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. from __future__ import absolute_import, unicode_literals
  2. import importlib
  3. import inspect
  4. import logging
  5. import numbers
  6. import os
  7. import sys
  8. import threading
  9. from copy import deepcopy
  10. from datetime import datetime, timedelta
  11. from functools import partial
  12. from kombu import Queue
  13. from kombu.utils import symbol_by_name
  14. from vine.utils import wraps
  15. from celery import Celery
  16. from celery.app import current_app
  17. from celery.backends.cache import CacheBackend, DummyClient
  18. from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
  19. from celery.utils.imports import qualname
  20. from case import (
  21. ANY, ContextMock, MagicMock, Mock, call, mock, skip, patch, sentinel,
  22. )
  23. from case import Case as _Case
  24. from case.utils import decorator
  25. __all__ = [
  26. 'ANY', 'ContextMock', 'MagicMock', 'Mock',
  27. 'call', 'mock', 'skip', 'patch', 'sentinel',
  28. 'AppCase', 'TaskMessage', 'TaskMessage1',
  29. 'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
  30. ]
  31. CASE_REDEFINES_SETUP = """\
  32. {name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
  33. """
  34. CASE_REDEFINES_TEARDOWN = """\
  35. {name} (subclass of AppCase) redefines private "tearDown", \
  36. should be: "teardown"\
  37. """
  38. CASE_LOG_REDIRECT_EFFECT = """\
  39. Test {0} did not disable LoggingProxy for {1}\
  40. """
  41. CASE_LOG_LEVEL_EFFECT = """\
  42. Test {0} Modified the level of the root logger\
  43. """
  44. CASE_LOG_HANDLER_EFFECT = """\
  45. Test {0} Modified handlers for the root logger\
  46. """
  47. CELERY_TEST_CONFIG = {
  48. #: Don't want log output when running suite.
  49. 'worker_hijack_root_logger': False,
  50. 'worker_log_color': False,
  51. 'task_default_queue': 'testcelery',
  52. 'task_default_exchange': 'testcelery',
  53. 'task_default_routing_key': 'testcelery',
  54. 'task_queues': (
  55. Queue('testcelery', routing_key='testcelery'),
  56. ),
  57. 'accept_content': ('json', 'pickle'),
  58. 'enable_utc': True,
  59. 'timezone': 'UTC',
  60. # Mongo results tests (only executed if installed and running)
  61. 'mongodb_backend_settings': {
  62. 'host': os.environ.get('MONGO_HOST') or 'localhost',
  63. 'port': os.environ.get('MONGO_PORT') or 27017,
  64. 'database': os.environ.get('MONGO_DB') or 'celery_unittests',
  65. 'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
  66. 'taskmeta_collection'),
  67. 'user': os.environ.get('MONGO_USER'),
  68. 'password': os.environ.get('MONGO_PASSWORD'),
  69. }
  70. }
  71. class Case(_Case):
  72. DeprecationWarning = CDeprecationWarning
  73. PendingDeprecationWarning = CPendingDeprecationWarning
  74. class Trap:
  75. def __getattr__(self, name):
  76. raise RuntimeError('Test depends on current_app')
  77. class UnitLogging(symbol_by_name(Celery.log_cls)):
  78. def __init__(self, *args, **kwargs):
  79. super(UnitLogging, self).__init__(*args, **kwargs)
  80. self.already_setup = True
  81. def UnitApp(name=None, set_as_current=False, log=UnitLogging,
  82. broker='memory://', backend='cache+memory://', **kwargs):
  83. app = Celery(name or 'celery.tests',
  84. set_as_current=set_as_current,
  85. log=log, broker=broker, backend=backend,
  86. **kwargs)
  87. app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
  88. return app
  89. def alive_threads():
  90. return [thread for thread in threading.enumerate() if thread.is_alive()]
  91. def depends_on_current_app(fun):
  92. if inspect.isclass(fun):
  93. fun.contained = False
  94. else:
  95. @wraps(fun)
  96. def __inner(self, *args, **kwargs):
  97. self.app.set_current()
  98. return fun(self, *args, **kwargs)
  99. return __inner
  100. class AppCase(Case):
  101. contained = True
  102. _threads_at_startup = [None]
  103. def __init__(self, *args, **kwargs):
  104. super(AppCase, self).__init__(*args, **kwargs)
  105. setUp = self.__class__.__dict__.get('setUp')
  106. tearDown = self.__class__.__dict__.get('tearDown')
  107. if setUp and not hasattr(setUp, '__wrapped__'):
  108. raise RuntimeError(
  109. CASE_REDEFINES_SETUP.format(name=qualname(self)),
  110. )
  111. if tearDown and not hasattr(tearDown, '__wrapped__'):
  112. raise RuntimeError(
  113. CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
  114. )
  115. def Celery(self, *args, **kwargs):
  116. return UnitApp(*args, **kwargs)
  117. def threads_at_startup(self):
  118. if self._threads_at_startup[0] is None:
  119. self._threads_at_startup[0] = alive_threads()
  120. return self._threads_at_startup[0]
  121. def setUp(self):
  122. self._threads_at_setup = self.threads_at_startup()
  123. from celery import _state
  124. from celery import result
  125. self._prev_res_join_block = result.task_join_will_block
  126. self._prev_state_join_block = _state.task_join_will_block
  127. result.task_join_will_block = \
  128. _state.task_join_will_block = lambda: False
  129. self._current_app = current_app()
  130. self._default_app = _state.default_app
  131. trap = Trap()
  132. self._prev_tls = _state._tls
  133. _state.set_default_app(trap)
  134. class NonTLS:
  135. current_app = trap
  136. _state._tls = NonTLS()
  137. self.app = self.Celery(set_as_current=False)
  138. if not self.contained:
  139. self.app.set_current()
  140. root = logging.getLogger()
  141. self.__rootlevel = root.level
  142. self.__roothandlers = root.handlers
  143. _state._set_task_join_will_block(False)
  144. try:
  145. self.setup()
  146. except:
  147. self._teardown_app()
  148. raise
  149. def _teardown_app(self):
  150. from celery import _state
  151. from celery import result
  152. from celery.utils.log import LoggingProxy
  153. assert sys.stdout
  154. assert sys.stderr
  155. assert sys.__stdout__
  156. assert sys.__stderr__
  157. this = self._get_test_name()
  158. result.task_join_will_block = self._prev_res_join_block
  159. _state.task_join_will_block = self._prev_state_join_block
  160. if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
  161. isinstance(sys.__stdout__, (LoggingProxy, Mock)):
  162. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
  163. if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
  164. isinstance(sys.__stderr__, (LoggingProxy, Mock)):
  165. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
  166. backend = self.app.__dict__.get('backend')
  167. if backend is not None:
  168. if isinstance(backend, CacheBackend):
  169. if isinstance(backend.client, DummyClient):
  170. backend.client.cache.clear()
  171. backend._cache.clear()
  172. from celery import _state
  173. _state._set_task_join_will_block(False)
  174. _state.set_default_app(self._default_app)
  175. _state._tls = self._prev_tls
  176. _state._tls.current_app = self._current_app
  177. if self.app is not self._current_app:
  178. self.app.close()
  179. self.app = None
  180. self.assertEqual(self._threads_at_setup, alive_threads())
  181. # Make sure no test left the shutdown flags enabled.
  182. from celery.worker import state as worker_state
  183. # check for EX_OK
  184. self.assertIsNot(worker_state.should_stop, False)
  185. self.assertIsNot(worker_state.should_terminate, False)
  186. # check for other true values
  187. self.assertFalse(worker_state.should_stop)
  188. self.assertFalse(worker_state.should_terminate)
  189. def _get_test_name(self):
  190. return '.'.join([self.__class__.__name__, self._testMethodName])
  191. def tearDown(self):
  192. try:
  193. self.teardown()
  194. finally:
  195. self._teardown_app()
  196. self.assert_no_logging_side_effect()
  197. def assert_no_logging_side_effect(self):
  198. this = self._get_test_name()
  199. root = logging.getLogger()
  200. if root.level != self.__rootlevel:
  201. raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
  202. if root.handlers != self.__roothandlers:
  203. raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
  204. def assert_signal_called(self, signal, **expected):
  205. return assert_signal_called(signal, **expected)
  206. def setup(self):
  207. pass
  208. def teardown(self):
  209. pass
  210. @decorator
  211. def assert_signal_called(signal, **expected):
  212. handler = Mock()
  213. call_handler = partial(handler)
  214. signal.connect(call_handler)
  215. try:
  216. yield handler
  217. finally:
  218. signal.disconnect(call_handler)
  219. handler.assert_called_with(signal=signal, **expected)
  220. def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
  221. errbacks=None, chain=None, shadow=None, utc=None, **options):
  222. from celery import uuid
  223. from kombu.serialization import dumps
  224. id = id or uuid()
  225. message = Mock(name='TaskMessage-{0}'.format(id))
  226. message.headers = {
  227. 'id': id,
  228. 'task': name,
  229. 'shadow': shadow,
  230. }
  231. embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
  232. message.headers.update(options)
  233. message.content_type, message.content_encoding, message.body = dumps(
  234. (args, kwargs, embed), serializer='json',
  235. )
  236. message.payload = (args, kwargs, embed)
  237. return message
  238. def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
  239. errbacks=None, chain=None, **options):
  240. from celery import uuid
  241. from kombu.serialization import dumps
  242. id = id or uuid()
  243. message = Mock(name='TaskMessage-{0}'.format(id))
  244. message.headers = {}
  245. message.payload = {
  246. 'task': name,
  247. 'id': id,
  248. 'args': args,
  249. 'kwargs': kwargs,
  250. 'callbacks': callbacks,
  251. 'errbacks': errbacks,
  252. }
  253. message.payload.update(options)
  254. message.content_type, message.content_encoding, message.body = dumps(
  255. message.payload,
  256. )
  257. return message
  258. def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
  259. sig.freeze()
  260. callbacks = sig.options.pop('link', None)
  261. errbacks = sig.options.pop('link_error', None)
  262. countdown = sig.options.pop('countdown', None)
  263. if countdown:
  264. eta = app.now() + timedelta(seconds=countdown)
  265. else:
  266. eta = sig.options.pop('eta', None)
  267. if eta and isinstance(eta, datetime):
  268. eta = eta.isoformat()
  269. expires = sig.options.pop('expires', None)
  270. if expires and isinstance(expires, numbers.Real):
  271. expires = app.now() + timedelta(seconds=expires)
  272. if expires and isinstance(expires, datetime):
  273. expires = expires.isoformat()
  274. return TaskMessage(
  275. sig.task, id=sig.id, args=sig.args,
  276. kwargs=sig.kwargs,
  277. callbacks=[dict(s) for s in callbacks] if callbacks else None,
  278. errbacks=[dict(s) for s in errbacks] if errbacks else None,
  279. eta=eta,
  280. expires=expires,
  281. utc=utc,
  282. **sig.options
  283. )
  284. def _old_patch(module, name, mocked):
  285. module = importlib.import_module(module)
  286. def _patch(fun):
  287. @wraps(fun)
  288. def __patched(*args, **kwargs):
  289. prev = getattr(module, name)
  290. setattr(module, name, mocked)
  291. try:
  292. return fun(*args, **kwargs)
  293. finally:
  294. setattr(module, name, prev)
  295. return __patched
  296. return _patch