test_trace.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from case import Mock, patch
  4. from kombu.exceptions import EncodeError
  5. from celery import group, uuid
  6. from celery import signals
  7. from celery import states
  8. from celery.app.task import Context
  9. from celery.exceptions import Ignore, Retry, Reject
  10. from celery.app.trace import (
  11. TraceInfo,
  12. build_tracer,
  13. get_log_policy,
  14. get_task_name,
  15. log_policy_reject,
  16. log_policy_ignore,
  17. log_policy_internal,
  18. log_policy_expected,
  19. log_policy_unexpected,
  20. trace_task,
  21. _trace_task_ret,
  22. _fast_trace_task,
  23. setup_worker_optimizations,
  24. reset_worker_optimizations,
  25. )
  26. def trace(app, task, args=(), kwargs={},
  27. propagate=False, eager=True, request=None, **opts):
  28. t = build_tracer(task.name, task,
  29. eager=eager, propagate=propagate, app=app, **opts)
  30. ret = t('id-1', args, kwargs, request)
  31. return ret.retval, ret.info
class TraceCase:
    """Shared per-test tasks and a tracing helper for the test classes below.

    NOTE(review): ``self.app`` is read but never assigned in this class;
    presumably it is injected by an autouse fixture in the test suite's
    conftest — confirm.
    """

    # NOTE(review): nose-style ``setup`` relies on pytest's compatibility
    # layer; newer pytest releases may require ``setup_method`` — confirm
    # against the pinned pytest version.
    def setup(self):
        # Register fresh, non-shared tasks on the test app for each test.
        # NOTE(review): the nested function names are probably used to derive
        # the registered task names, so avoid renaming them.
        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        self.add = add

        # Same body as ``add`` but registered with ignore_result=True.
        @self.app.task(shared=False, ignore_result=True)
        def add_cast(x, y):
            return x + y
        self.add_cast = add_cast

        # Raises whatever exception instance it is given as ``exc``.
        @self.app.task(shared=False)
        def raises(exc):
            raise exc
        self.raises = raises

    def trace(self, *args, **kwargs):
        # ``trace`` here resolves to the module-level helper function (not
        # this method); the test case's app is bound as its first argument.
        return trace(self.app, *args, **kwargs)
  48. class test_trace(TraceCase):
  49. def test_trace_successful(self):
  50. retval, info = self.trace(self.add, (2, 2), {})
  51. assert info is None
  52. assert retval == 4
  53. def test_trace_on_success(self):
  54. @self.app.task(shared=False, on_success=Mock())
  55. def add_with_success(x, y):
  56. return x + y
  57. self.trace(add_with_success, (2, 2), {})
  58. add_with_success.on_success.assert_called()
  59. def test_get_log_policy(self):
  60. einfo = Mock(name='einfo')
  61. einfo.internal = False
  62. assert get_log_policy(self.add, einfo, Reject()) is log_policy_reject
  63. assert get_log_policy(self.add, einfo, Ignore()) is log_policy_ignore
  64. self.add.throws = (TypeError,)
  65. assert (get_log_policy(self.add, einfo, KeyError()) is
  66. log_policy_unexpected)
  67. assert (get_log_policy(self.add, einfo, TypeError()) is
  68. log_policy_expected)
  69. einfo2 = Mock(name='einfo2')
  70. einfo2.internal = True
  71. assert (get_log_policy(self.add, einfo2, KeyError()) is
  72. log_policy_internal)
  73. def test_get_task_name(self):
  74. assert get_task_name(Context({}), 'default') == 'default'
  75. assert get_task_name(Context({'shadow': None}), 'default') == 'default'
  76. assert get_task_name(Context({'shadow': ''}), 'default') == 'default'
  77. assert get_task_name(Context({'shadow': 'test'}), 'default') == 'test'
  78. def test_trace_after_return(self):
  79. @self.app.task(shared=False, after_return=Mock())
  80. def add_with_after_return(x, y):
  81. return x + y
  82. self.trace(add_with_after_return, (2, 2), {})
  83. add_with_after_return.after_return.assert_called()
  84. def test_with_prerun_receivers(self):
  85. on_prerun = Mock()
  86. signals.task_prerun.connect(on_prerun)
  87. try:
  88. self.trace(self.add, (2, 2), {})
  89. on_prerun.assert_called()
  90. finally:
  91. signals.task_prerun.receivers[:] = []
  92. def test_with_postrun_receivers(self):
  93. on_postrun = Mock()
  94. signals.task_postrun.connect(on_postrun)
  95. try:
  96. self.trace(self.add, (2, 2), {})
  97. on_postrun.assert_called()
  98. finally:
  99. signals.task_postrun.receivers[:] = []
  100. def test_with_success_receivers(self):
  101. on_success = Mock()
  102. signals.task_success.connect(on_success)
  103. try:
  104. self.trace(self.add, (2, 2), {})
  105. on_success.assert_called()
  106. finally:
  107. signals.task_success.receivers[:] = []
  108. def test_when_chord_part(self):
  109. @self.app.task(shared=False)
  110. def add(x, y):
  111. return x + y
  112. add.backend = Mock()
  113. request = {'chord': uuid()}
  114. self.trace(add, (2, 2), {}, request=request)
  115. add.backend.mark_as_done.assert_called()
  116. args, kwargs = add.backend.mark_as_done.call_args
  117. assert args[0] == 'id-1'
  118. assert args[1] == 4
  119. assert args[2].chord == request['chord']
  120. assert not args[3]
  121. def test_when_backend_cleanup_raises(self):
  122. @self.app.task(shared=False)
  123. def add(x, y):
  124. return x + y
  125. add.backend = Mock(name='backend')
  126. add.backend.process_cleanup.side_effect = KeyError()
  127. self.trace(add, (2, 2), {}, eager=False)
  128. add.backend.process_cleanup.assert_called_with()
  129. add.backend.process_cleanup.side_effect = MemoryError()
  130. with pytest.raises(MemoryError):
  131. self.trace(add, (2, 2), {}, eager=False)
  132. def test_when_Ignore(self):
  133. @self.app.task(shared=False)
  134. def ignored():
  135. raise Ignore()
  136. retval, info = self.trace(ignored, (), {})
  137. assert info.state == states.IGNORED
  138. def test_when_Reject(self):
  139. @self.app.task(shared=False)
  140. def rejecting():
  141. raise Reject()
  142. retval, info = self.trace(rejecting, (), {})
  143. assert info.state == states.REJECTED
  144. def test_backend_cleanup_raises(self):
  145. self.add.backend.process_cleanup = Mock()
  146. self.add.backend.process_cleanup.side_effect = RuntimeError()
  147. self.trace(self.add, (2, 2), {})
  148. @patch('celery.canvas.maybe_signature')
  149. def test_callbacks__scalar(self, maybe_signature):
  150. sig = Mock(name='sig')
  151. request = {'callbacks': [sig], 'root_id': 'root'}
  152. maybe_signature.return_value = sig
  153. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  154. sig.apply_async.assert_called_with(
  155. (4,), parent_id='id-1', root_id='root',
  156. )
  157. @patch('celery.canvas.maybe_signature')
  158. def test_chain_proto2(self, maybe_signature):
  159. sig = Mock(name='sig')
  160. sig2 = Mock(name='sig2')
  161. request = {'chain': [sig2, sig], 'root_id': 'root'}
  162. maybe_signature.return_value = sig
  163. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  164. sig.apply_async.assert_called_with(
  165. (4, ), parent_id='id-1', root_id='root',
  166. chain=[sig2],
  167. )
  168. @patch('celery.canvas.maybe_signature')
  169. def test_callbacks__EncodeError(self, maybe_signature):
  170. sig = Mock(name='sig')
  171. request = {'callbacks': [sig], 'root_id': 'root'}
  172. maybe_signature.return_value = sig
  173. sig.apply_async.side_effect = EncodeError()
  174. retval, einfo = self.trace(self.add, (2, 2), {}, request=request)
  175. assert einfo.state == states.FAILURE
  176. @patch('celery.canvas.maybe_signature')
  177. @patch('celery.app.trace.group.apply_async')
  178. def test_callbacks__sigs(self, group_, maybe_signature):
  179. sig1 = Mock(name='sig')
  180. sig2 = Mock(name='sig2')
  181. sig3 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
  182. sig3.apply_async = Mock(name='gapply')
  183. request = {'callbacks': [sig1, sig3, sig2], 'root_id': 'root'}
  184. def passt(s, *args, **kwargs):
  185. return s
  186. maybe_signature.side_effect = passt
  187. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  188. group_.assert_called_with(
  189. (4,), parent_id='id-1', root_id='root',
  190. )
  191. sig3.apply_async.assert_called_with(
  192. (4,), parent_id='id-1', root_id='root',
  193. )
  194. @patch('celery.canvas.maybe_signature')
  195. @patch('celery.app.trace.group.apply_async')
  196. def test_callbacks__only_groups(self, group_, maybe_signature):
  197. sig1 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
  198. sig2 = group([Mock(name='g3'), Mock(name='g4')], app=self.app)
  199. sig1.apply_async = Mock(name='gapply')
  200. sig2.apply_async = Mock(name='gapply')
  201. request = {'callbacks': [sig1, sig2], 'root_id': 'root'}
  202. def passt(s, *args, **kwargs):
  203. return s
  204. maybe_signature.side_effect = passt
  205. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  206. sig1.apply_async.assert_called_with(
  207. (4,), parent_id='id-1', root_id='root',
  208. )
  209. sig2.apply_async.assert_called_with(
  210. (4,), parent_id='id-1', root_id='root',
  211. )
  212. def test_trace_SystemExit(self):
  213. with pytest.raises(SystemExit):
  214. self.trace(self.raises, (SystemExit(),), {})
  215. def test_trace_Retry(self):
  216. exc = Retry('foo', 'bar')
  217. _, info = self.trace(self.raises, (exc,), {})
  218. assert info.state == states.RETRY
  219. assert info.retval is exc
  220. def test_trace_exception(self):
  221. exc = KeyError('foo')
  222. _, info = self.trace(self.raises, (exc,), {})
  223. assert info.state == states.FAILURE
  224. assert info.retval is exc
  225. def test_trace_task_ret__no_content_type(self):
  226. _trace_task_ret(
  227. self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
  228. app=self.app,
  229. )
  230. def test_fast_trace_task__no_content_type(self):
  231. self.app.tasks[self.add.name].__trace__ = build_tracer(
  232. self.add.name, self.add, app=self.app,
  233. )
  234. _fast_trace_task(
  235. self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
  236. app=self.app, _loc=[self.app.tasks, {}, 'hostname']
  237. )
  238. def test_trace_exception_propagate(self):
  239. with pytest.raises(KeyError):
  240. self.trace(self.raises, (KeyError('foo'),), {}, propagate=True)
  241. @patch('celery.app.trace.build_tracer')
  242. @patch('celery.app.trace.report_internal_error')
  243. def test_outside_body_error(self, report_internal_error, build_tracer):
  244. tracer = Mock()
  245. tracer.side_effect = KeyError('foo')
  246. build_tracer.return_value = tracer
  247. @self.app.task(shared=False)
  248. def xtask():
  249. pass
  250. trace_task(xtask, 'uuid', (), {})
  251. assert report_internal_error.call_count
  252. assert xtask.__trace__ is tracer
  253. class test_TraceInfo(TraceCase):
  254. class TI(TraceInfo):
  255. __slots__ = TraceInfo.__slots__ + ('__dict__',)
  256. def test_handle_error_state(self):
  257. x = self.TI(states.FAILURE)
  258. x.handle_failure = Mock()
  259. x.handle_error_state(self.add_cast, self.add_cast.request)
  260. x.handle_failure.assert_called_with(
  261. self.add_cast, self.add_cast.request,
  262. store_errors=self.add_cast.store_errors_even_if_ignored,
  263. call_errbacks=True,
  264. )
  265. @patch('celery.app.trace.ExceptionInfo')
  266. def test_handle_reject(self, ExceptionInfo):
  267. x = self.TI(states.FAILURE)
  268. x._log_error = Mock(name='log_error')
  269. req = Mock(name='req')
  270. x.handle_reject(self.add, req)
  271. x._log_error.assert_called_with(self.add, req, ExceptionInfo())
  272. class test_stackprotection:
  273. def test_stackprotection(self):
  274. setup_worker_optimizations(self.app)
  275. try:
  276. @self.app.task(shared=False, bind=True)
  277. def foo(self, i):
  278. if i:
  279. return foo(0)
  280. return self.request
  281. assert foo(1).called_directly
  282. finally:
  283. reset_worker_optimizations()