test_trace.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from case import Mock, patch
  4. from celery import group, signals, states, uuid
  5. from celery.app.task import Context
  6. from celery.app.trace import (TraceInfo, _fast_trace_task, _trace_task_ret,
  7. build_tracer, get_log_policy, get_task_name,
  8. log_policy_expected, log_policy_ignore,
  9. log_policy_internal, log_policy_reject,
  10. log_policy_unexpected,
  11. reset_worker_optimizations,
  12. setup_worker_optimizations, trace_task)
  13. from celery.exceptions import Ignore, Reject, Retry
  14. from kombu.exceptions import EncodeError
  15. def trace(app, task, args=(), kwargs={},
  16. propagate=False, eager=True, request=None, **opts):
  17. t = build_tracer(task.name, task,
  18. eager=eager, propagate=propagate, app=app, **opts)
  19. ret = t('id-1', args, kwargs, request)
  20. return ret.retval, ret.info
  21. class TraceCase:
  22. def setup(self):
  23. @self.app.task(shared=False)
  24. def add(x, y):
  25. return x + y
  26. self.add = add
  27. @self.app.task(shared=False, ignore_result=True)
  28. def add_cast(x, y):
  29. return x + y
  30. self.add_cast = add_cast
  31. @self.app.task(shared=False)
  32. def raises(exc):
  33. raise exc
  34. self.raises = raises
  35. def trace(self, *args, **kwargs):
  36. return trace(self.app, *args, **kwargs)
  37. class test_trace(TraceCase):
  38. def test_trace_successful(self):
  39. retval, info = self.trace(self.add, (2, 2), {})
  40. assert info is None
  41. assert retval == 4
  42. def test_trace_on_success(self):
  43. @self.app.task(shared=False, on_success=Mock())
  44. def add_with_success(x, y):
  45. return x + y
  46. self.trace(add_with_success, (2, 2), {})
  47. add_with_success.on_success.assert_called()
  48. def test_get_log_policy(self):
  49. einfo = Mock(name='einfo')
  50. einfo.internal = False
  51. assert get_log_policy(self.add, einfo, Reject()) is log_policy_reject
  52. assert get_log_policy(self.add, einfo, Ignore()) is log_policy_ignore
  53. self.add.throws = (TypeError,)
  54. assert (get_log_policy(self.add, einfo, KeyError()) is
  55. log_policy_unexpected)
  56. assert (get_log_policy(self.add, einfo, TypeError()) is
  57. log_policy_expected)
  58. einfo2 = Mock(name='einfo2')
  59. einfo2.internal = True
  60. assert (get_log_policy(self.add, einfo2, KeyError()) is
  61. log_policy_internal)
  62. def test_get_task_name(self):
  63. assert get_task_name(Context({}), 'default') == 'default'
  64. assert get_task_name(Context({'shadow': None}), 'default') == 'default'
  65. assert get_task_name(Context({'shadow': ''}), 'default') == 'default'
  66. assert get_task_name(Context({'shadow': 'test'}), 'default') == 'test'
  67. def test_trace_after_return(self):
  68. @self.app.task(shared=False, after_return=Mock())
  69. def add_with_after_return(x, y):
  70. return x + y
  71. self.trace(add_with_after_return, (2, 2), {})
  72. add_with_after_return.after_return.assert_called()
  73. def test_with_prerun_receivers(self):
  74. on_prerun = Mock()
  75. signals.task_prerun.connect(on_prerun)
  76. try:
  77. self.trace(self.add, (2, 2), {})
  78. on_prerun.assert_called()
  79. finally:
  80. signals.task_prerun.receivers[:] = []
  81. def test_with_postrun_receivers(self):
  82. on_postrun = Mock()
  83. signals.task_postrun.connect(on_postrun)
  84. try:
  85. self.trace(self.add, (2, 2), {})
  86. on_postrun.assert_called()
  87. finally:
  88. signals.task_postrun.receivers[:] = []
  89. def test_with_success_receivers(self):
  90. on_success = Mock()
  91. signals.task_success.connect(on_success)
  92. try:
  93. self.trace(self.add, (2, 2), {})
  94. on_success.assert_called()
  95. finally:
  96. signals.task_success.receivers[:] = []
  97. def test_when_chord_part(self):
  98. @self.app.task(shared=False)
  99. def add(x, y):
  100. return x + y
  101. add.backend = Mock()
  102. request = {'chord': uuid()}
  103. self.trace(add, (2, 2), {}, request=request)
  104. add.backend.mark_as_done.assert_called()
  105. args, kwargs = add.backend.mark_as_done.call_args
  106. assert args[0] == 'id-1'
  107. assert args[1] == 4
  108. assert args[2].chord == request['chord']
  109. assert not args[3]
  110. def test_when_backend_cleanup_raises(self):
  111. @self.app.task(shared=False)
  112. def add(x, y):
  113. return x + y
  114. add.backend = Mock(name='backend')
  115. add.backend.process_cleanup.side_effect = KeyError()
  116. self.trace(add, (2, 2), {}, eager=False)
  117. add.backend.process_cleanup.assert_called_with()
  118. add.backend.process_cleanup.side_effect = MemoryError()
  119. with pytest.raises(MemoryError):
  120. self.trace(add, (2, 2), {}, eager=False)
  121. def test_when_Ignore(self):
  122. @self.app.task(shared=False)
  123. def ignored():
  124. raise Ignore()
  125. retval, info = self.trace(ignored, (), {})
  126. assert info.state == states.IGNORED
  127. def test_when_Reject(self):
  128. @self.app.task(shared=False)
  129. def rejecting():
  130. raise Reject()
  131. retval, info = self.trace(rejecting, (), {})
  132. assert info.state == states.REJECTED
  133. def test_backend_cleanup_raises(self):
  134. self.add.backend.process_cleanup = Mock()
  135. self.add.backend.process_cleanup.side_effect = RuntimeError()
  136. self.trace(self.add, (2, 2), {})
  137. @patch('celery.canvas.maybe_signature')
  138. def test_callbacks__scalar(self, maybe_signature):
  139. sig = Mock(name='sig')
  140. request = {'callbacks': [sig], 'root_id': 'root'}
  141. maybe_signature.return_value = sig
  142. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  143. sig.apply_async.assert_called_with(
  144. (4,), parent_id='id-1', root_id='root',
  145. )
  146. @patch('celery.canvas.maybe_signature')
  147. def test_chain_proto2(self, maybe_signature):
  148. sig = Mock(name='sig')
  149. sig2 = Mock(name='sig2')
  150. request = {'chain': [sig2, sig], 'root_id': 'root'}
  151. maybe_signature.return_value = sig
  152. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  153. sig.apply_async.assert_called_with(
  154. (4, ), parent_id='id-1', root_id='root',
  155. chain=[sig2],
  156. )
  157. @patch('celery.canvas.maybe_signature')
  158. def test_callbacks__EncodeError(self, maybe_signature):
  159. sig = Mock(name='sig')
  160. request = {'callbacks': [sig], 'root_id': 'root'}
  161. maybe_signature.return_value = sig
  162. sig.apply_async.side_effect = EncodeError()
  163. retval, einfo = self.trace(self.add, (2, 2), {}, request=request)
  164. assert einfo.state == states.FAILURE
  165. @patch('celery.canvas.maybe_signature')
  166. @patch('celery.app.trace.group.apply_async')
  167. def test_callbacks__sigs(self, group_, maybe_signature):
  168. sig1 = Mock(name='sig')
  169. sig2 = Mock(name='sig2')
  170. sig3 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
  171. sig3.apply_async = Mock(name='gapply')
  172. request = {'callbacks': [sig1, sig3, sig2], 'root_id': 'root'}
  173. def passt(s, *args, **kwargs):
  174. return s
  175. maybe_signature.side_effect = passt
  176. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  177. group_.assert_called_with(
  178. (4,), parent_id='id-1', root_id='root',
  179. )
  180. sig3.apply_async.assert_called_with(
  181. (4,), parent_id='id-1', root_id='root',
  182. )
  183. @patch('celery.canvas.maybe_signature')
  184. @patch('celery.app.trace.group.apply_async')
  185. def test_callbacks__only_groups(self, group_, maybe_signature):
  186. sig1 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
  187. sig2 = group([Mock(name='g3'), Mock(name='g4')], app=self.app)
  188. sig1.apply_async = Mock(name='gapply')
  189. sig2.apply_async = Mock(name='gapply')
  190. request = {'callbacks': [sig1, sig2], 'root_id': 'root'}
  191. def passt(s, *args, **kwargs):
  192. return s
  193. maybe_signature.side_effect = passt
  194. retval, _ = self.trace(self.add, (2, 2), {}, request=request)
  195. sig1.apply_async.assert_called_with(
  196. (4,), parent_id='id-1', root_id='root',
  197. )
  198. sig2.apply_async.assert_called_with(
  199. (4,), parent_id='id-1', root_id='root',
  200. )
  201. def test_trace_SystemExit(self):
  202. with pytest.raises(SystemExit):
  203. self.trace(self.raises, (SystemExit(),), {})
  204. def test_trace_Retry(self):
  205. exc = Retry('foo', 'bar')
  206. _, info = self.trace(self.raises, (exc,), {})
  207. assert info.state == states.RETRY
  208. assert info.retval is exc
  209. def test_trace_exception(self):
  210. exc = KeyError('foo')
  211. _, info = self.trace(self.raises, (exc,), {})
  212. assert info.state == states.FAILURE
  213. assert info.retval is exc
  214. def test_trace_task_ret__no_content_type(self):
  215. _trace_task_ret(
  216. self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
  217. app=self.app,
  218. )
  219. def test_fast_trace_task__no_content_type(self):
  220. self.app.tasks[self.add.name].__trace__ = build_tracer(
  221. self.add.name, self.add, app=self.app,
  222. )
  223. _fast_trace_task(
  224. self.add.name, 'id1', {}, ((2, 2), {}, {}), None, None,
  225. app=self.app, _loc=[self.app.tasks, {}, 'hostname']
  226. )
  227. def test_trace_exception_propagate(self):
  228. with pytest.raises(KeyError):
  229. self.trace(self.raises, (KeyError('foo'),), {}, propagate=True)
  230. @patch('celery.app.trace.build_tracer')
  231. @patch('celery.app.trace.report_internal_error')
  232. def test_outside_body_error(self, report_internal_error, build_tracer):
  233. tracer = Mock()
  234. tracer.side_effect = KeyError('foo')
  235. build_tracer.return_value = tracer
  236. @self.app.task(shared=False)
  237. def xtask():
  238. pass
  239. trace_task(xtask, 'uuid', (), {})
  240. assert report_internal_error.call_count
  241. assert xtask.__trace__ is tracer
  242. class test_TraceInfo(TraceCase):
  243. class TI(TraceInfo):
  244. __slots__ = TraceInfo.__slots__ + ('__dict__',)
  245. def test_handle_error_state(self):
  246. x = self.TI(states.FAILURE)
  247. x.handle_failure = Mock()
  248. x.handle_error_state(self.add_cast, self.add_cast.request)
  249. x.handle_failure.assert_called_with(
  250. self.add_cast, self.add_cast.request,
  251. store_errors=self.add_cast.store_errors_even_if_ignored,
  252. call_errbacks=True,
  253. )
  254. @patch('celery.app.trace.ExceptionInfo')
  255. def test_handle_reject(self, ExceptionInfo):
  256. x = self.TI(states.FAILURE)
  257. x._log_error = Mock(name='log_error')
  258. req = Mock(name='req')
  259. x.handle_reject(self.add, req)
  260. x._log_error.assert_called_with(self.add, req, ExceptionInfo())
  261. class test_stackprotection:
  262. def test_stackprotection(self):
  263. setup_worker_optimizations(self.app)
  264. try:
  265. @self.app.task(shared=False, bind=True)
  266. def foo(self, i):
  267. if i:
  268. return foo(0)
  269. return self.request
  270. assert foo(1).called_directly
  271. finally:
  272. reset_worker_optimizations()