test_trace.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from case import Mock, patch
  4. from kombu.exceptions import EncodeError
  5. from celery import group, signals, states, uuid
  6. from celery.app.task import Context
  7. from celery.app.trace import (TraceInfo, _fast_trace_task, _trace_task_ret,
  8. build_tracer, get_log_policy, get_task_name,
  9. log_policy_expected, log_policy_ignore,
  10. log_policy_internal, log_policy_reject,
  11. log_policy_unexpected,
  12. reset_worker_optimizations,
  13. setup_worker_optimizations, trace_task)
  14. from celery.exceptions import Ignore, Reject, Retry
  15. def trace(app, task, args=(), kwargs={},
  16. propagate=False, eager=True, request=None, **opts):
  17. t = build_tracer(task.name, task,
  18. eager=eager, propagate=propagate, app=app, **opts)
  19. ret = t('id-1', args, kwargs, request)
  20. return ret.retval, ret.info
  21. class TraceCase:
  22. def setup(self):
  23. @self.app.task(shared=False)
  24. def add(x, y):
  25. return x + y
  26. self.add = add
  27. @self.app.task(shared=False, ignore_result=True)
  28. def add_cast(x, y):
  29. return x + y
  30. self.add_cast = add_cast
  31. @self.app.task(shared=False)
  32. def raises(exc):
  33. raise exc
  34. self.raises = raises
  35. def trace(self, *args, **kwargs):
  36. return trace(self.app, *args, **kwargs)
  37. class test_trace(TraceCase):
  38. def test_trace_successful(self):
  39. retval, info = self.trace(self.add, (2, 2), {})
  40. assert info is None
  41. assert retval == 4
  42. def test_trace_on_success(self):
  43. @self.app.task(shared=False, on_success=Mock())
  44. def add_with_success(x, y):
  45. return x + y
  46. self.trace(add_with_success, (2, 2), {})
  47. add_with_success.on_success.assert_called()
  48. def test_get_log_policy(self):
  49. einfo = Mock(name='einfo')
  50. einfo.internal = False
  51. assert get_log_policy(self.add, einfo, Reject()) is log_policy_reject
  52. assert get_log_policy(self.add, einfo, Ignore()) is log_policy_ignore
  53. self.add.throws = (TypeError,)
  54. assert (get_log_policy(self.add, einfo, KeyError()) is
  55. log_policy_unexpected)
  56. assert (get_log_policy(self.add, einfo, TypeError()) is
  57. log_policy_expected)
  58. einfo2 = Mock(name='einfo2')
  59. einfo2.internal = True
  60. assert (get_log_policy(self.add, einfo2, KeyError()) is
  61. log_policy_internal)
  62. def test_get_task_name(self):
  63. assert get_task_name(Context({}), 'default') == 'default'
  64. assert get_task_name(Context({'shadow': None}), 'default') == 'default'
  65. assert get_task_name(Context({'shadow': ''}), 'default') == 'default'
  66. assert get_task_name(Context({'shadow': 'test'}), 'default') == 'test'
  67. def test_trace_after_return(self):
  68. @self.app.task(shared=False, after_return=Mock())
  69. def add_with_after_return(x, y):
  70. return x + y
  71. self.trace(add_with_after_return, (2, 2), {})
  72. add_with_after_return.after_return.assert_called()
  73. def test_with_prerun_receivers(self):
  74. on_prerun = Mock()
  75. signals.task_prerun.connect(on_prerun)
  76. try:
  77. self.trace(self.add, (2, 2), {})
  78. on_prerun.assert_called()
  79. finally:
  80. signals.task_prerun.receivers[:] = []
  81. def test_with_postrun_receivers(self):
  82. on_postrun = Mock()
  83. signals.task_postrun.connect(on_postrun)
  84. try:
  85. self.trace(self.add, (2, 2), {})
  86. on_postrun.assert_called()
  87. finally:
  88. signals.task_postrun.receivers[:] = []
  89. def test_with_success_receivers(self):
  90. on_success = Mock()
  91. signals.task_success.connect(on_success)
  92. try:
  93. self.trace(self.add, (2, 2), {})
  94. on_success.assert_called()
  95. finally:
  96. signals.task_success.receivers[:] = []
  97. def test_when_chord_part(self):
  98. @self.app.task(shared=False)
  99. def add(x, y):
  100. return x + y
  101. add.backend = Mock()
  102. request = {'chord': uuid()}
  103. self.trace(add, (2, 2), {}, request=request)
  104. add.backend.mark_as_done.assert_called()
  105. args, kwargs = add.backend.mark_as_done.call_args
  106. assert args[0] == 'id-1'
  107. assert args[1] == 4
  108. assert args[2].chord == request['chord']
  109. assert not args[3]
  110. def test_when_backend_cleanup_raises(self):
  111. @self.app.task(shared=False)
  112. def add(x, y):
  113. return x + y
  114. add.backend = Mock(name='backend')
  115. add.backend.process_cleanup.side_effect = KeyError()
  116. self.trace(add, (2, 2), {}, eager=False)
  117. add.backend.process_cleanup.assert_called_with()
  118. add.backend.process_cleanup.side_effect = MemoryError()
  119. with pytest.raises(MemoryError):
  120. self.trace(add, (2, 2), {}, eager=False)
  121. def test_when_Ignore(self):
  122. @self.app.task(shared=False)
  123. def ignored():
  124. raise Ignore()
  125. retval, info = self.trace(ignored, (), {})
  126. assert info.state == states.IGNORED
  127. def test_when_Reject(self):
  128. @self.app.task(shared=False)
  129. def rejecting():
  130. raise Reject()
  131. retval, info = self.trace(rejecting, (), {})
  132. assert info.state == states.REJECTED
  133. def test_backend_cleanup_raises(self):
  134. self.add.backend.process_cleanup = Mock()
  135. self.add.backend.process_cleanup.side_effect = RuntimeError()
  136. self.trace(self.add, (2, 2), {})
  137. @patch('celery.canvas.maybe_signature')
  138. def test_callbacks__scalar(self, maybe_signature):
  139. sig = Mock(name='sig')
  140. request = {'callbacks': [sig], 'root_id': 'root'}
  141. maybe_signature.return_value = sig
  142. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  143. sig.apply_async.assert_called_with(
  144. (4,), parent_id='id-1', root_id='root',
  145. )
  146. @patch('celery.canvas.maybe_signature')
  147. def test_chain_proto2(self, maybe_signature):
  148. sig = Mock(name='sig')
  149. sig2 = Mock(name='sig2')
  150. request = {'chain': [sig2, sig], 'root_id': 'root'}
  151. maybe_signature.return_value = sig
  152. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  153. sig.apply_async.assert_called_with(
  154. (4, ), parent_id='id-1', root_id='root',
  155. chain=[sig2],
  156. )
  157. @patch('celery.canvas.maybe_signature')
  158. def test_callbacks__EncodeError(self, maybe_signature):
  159. sig = Mock(name='sig')
  160. request = {'callbacks': [sig], 'root_id': 'root'}
  161. maybe_signature.return_value = sig
  162. sig.apply_async.side_effect = EncodeError()
  163. retval, einfo = self.trace(self.add, (2, 2), {}, request=request)
  164. assert einfo.state == states.FAILURE
  165. @patch('celery.canvas.maybe_signature')
  166. @patch('celery.app.trace.group.apply_async')
  167. def test_callbacks__sigs(self, group_, maybe_signature):
  168. sig1 = Mock(name='sig')
  169. sig2 = Mock(name='sig2')
  170. sig3 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
  171. sig3.apply_async = Mock(name='gapply')
  172. request = {'callbacks': [sig1, sig3, sig2], 'root_id': 'root'}
  173. def passt(s, *args, **kwargs):
  174. return s
  175. maybe_signature.side_effect = passt
  176. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  177. group_.assert_called_with(
  178. (4,), parent_id='id-1', root_id='root',
  179. )
  180. sig3.apply_async.assert_called_with(
  181. (4,), parent_id='id-1', root_id='root',
  182. )
  183. @patch('celery.canvas.maybe_signature')
  184. @patch('celery.app.trace.group.apply_async')
  185. def test_callbacks__only_groups(self, group_, maybe_signature):
  186. sig1 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
  187. sig2 = group([Mock(name='g3'), Mock(name='g4')], app=self.app)
  188. sig1.apply_async = Mock(name='gapply')
  189. sig2.apply_async = Mock(name='gapply')
  190. request = {'callbacks': [sig1, sig2], 'root_id': 'root'}
  191. def passt(s, *args, **kwargs):
  192. return s
  193. maybe_signature.side_effect = passt
  194. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  195. sig1.apply_async.assert_called_with(
  196. (4,), parent_id='id-1', root_id='root',
  197. )
  198. sig2.apply_async.assert_called_with(
  199. (4,), parent_id='id-1', root_id='root',
  200. )
  201. def test_trace_SystemExit(self):
  202. with pytest.raises(SystemExit):
  203. self.trace(self.raises, (SystemExit(),), {})
  204. def test_trace_Retry(self):
  205. exc = Retry('foo', 'bar')
  206. _, info = self.trace(self.raises, (exc,), {})
  207. assert info.state == states.RETRY
  208. assert info.retval is exc
  209. def test_trace_exception(self):
  210. exc = KeyError('foo')
  211. _, info = self.trace(self.raises, (exc,), {})
  212. assert info.state == states.FAILURE
  213. assert info.retval is exc
  214. def test_trace_task_ret__no_content_type(self):
  215. _trace_task_ret(
  216. self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
  217. app=self.app,
  218. )
  219. def test_fast_trace_task__no_content_type(self):
  220. self.app.tasks[self.add.name].__trace__ = build_tracer(
  221. self.add.name, self.add, app=self.app,
  222. )
  223. _fast_trace_task(
  224. self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
  225. app=self.app, _loc=[self.app.tasks, {}, 'hostname']
  226. )
  227. def test_trace_exception_propagate(self):
  228. with pytest.raises(KeyError):
  229. self.trace(self.raises, (KeyError('foo'),), {}, propagate=True)
  230. @patch('celery.app.trace.build_tracer')
  231. @patch('celery.app.trace.report_internal_error')
  232. def test_outside_body_error(self, report_internal_error, build_tracer):
  233. tracer = Mock()
  234. tracer.side_effect = KeyError('foo')
  235. build_tracer.return_value = tracer
  236. @self.app.task(shared=False)
  237. def xtask():
  238. pass
  239. trace_task(xtask, 'uuid', (), {})
  240. assert report_internal_error.call_count
  241. assert xtask.__trace__ is tracer
  242. class test_TraceInfo(TraceCase):
  243. class TI(TraceInfo):
  244. __slots__ = TraceInfo.__slots__ + ('__dict__',)
  245. def test_handle_error_state(self):
  246. x = self.TI(states.FAILURE)
  247. x.handle_failure = Mock()
  248. x.handle_error_state(self.add_cast, self.add_cast.request)
  249. x.handle_failure.assert_called_with(
  250. self.add_cast, self.add_cast.request,
  251. store_errors=self.add_cast.store_errors_even_if_ignored,
  252. call_errbacks=True,
  253. )
  254. @patch('celery.app.trace.ExceptionInfo')
  255. def test_handle_reject(self, ExceptionInfo):
  256. x = self.TI(states.FAILURE)
  257. x._log_error = Mock(name='log_error')
  258. req = Mock(name='req')
  259. x.handle_reject(self.add, req)
  260. x._log_error.assert_called_with(self.add, req, ExceptionInfo())
  261. class test_stackprotection:
  262. def test_stackprotection(self):
  263. setup_worker_optimizations(self.app)
  264. try:
  265. @self.app.task(shared=False, bind=True)
  266. def foo(self, i):
  267. if i:
  268. return foo(0)
  269. return self.request
  270. assert foo(1).called_directly
  271. finally:
  272. reset_worker_optimizations()