test_worker_job.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. # -*- coding: utf-8 -*-
  2. import logging
  3. import simplejson
  4. import sys
  5. import unittest2 as unittest
  6. from StringIO import StringIO
  7. from kombu.transport.base import Message
  8. from celery import states
  9. from celery.app import app_or_default
  10. from celery.datastructures import ExceptionInfo
  11. from celery.decorators import task as task_dec
  12. from celery.exceptions import RetryTaskError, NotRegistered
  13. from celery.log import setup_logger
  14. from celery.result import AsyncResult
  15. from celery.task.base import Task
  16. from celery.utils import gen_unique_id
  17. from celery.worker.job import WorkerTaskTrace, TaskRequest
  18. from celery.worker.job import execute_and_trace, AlreadyExecutedError
  19. from celery.worker.job import InvalidTaskError
  20. from celery.worker.state import revoked
  21. from celery.tests.compat import catch_warnings
  22. from celery.tests.utils import execute_context
  23. scratch = {"ACK": False}
  24. some_kwargs_scratchpad = {}
  25. def jail(task_id, task_name, args, kwargs):
  26. return WorkerTaskTrace(task_name, task_id, args, kwargs)()
  27. def on_ack():
  28. scratch["ACK"] = True
  29. @task_dec()
  30. def mytask(i, **kwargs):
  31. return i ** i
  32. @task_dec # traverses coverage for decorator without parens
  33. def mytask_no_kwargs(i):
  34. return i ** i
  35. class MyTaskIgnoreResult(Task):
  36. ignore_result = True
  37. def run(self, i):
  38. return i ** i
  39. @task_dec()
  40. def mytask_some_kwargs(i, logfile):
  41. some_kwargs_scratchpad["logfile"] = logfile
  42. return i ** i
  43. @task_dec()
  44. def mytask_raising(i, **kwargs):
  45. raise KeyError(i)
  46. class test_RetryTaskError(unittest.TestCase):
  47. def test_retry_task_error(self):
  48. try:
  49. raise Exception("foo")
  50. except Exception, exc:
  51. ret = RetryTaskError("Retrying task", exc)
  52. self.assertEqual(ret.exc, exc)
  53. class test_WorkerTaskTrace(unittest.TestCase):
  54. def test_execute_jail_success(self):
  55. ret = jail(gen_unique_id(), mytask.name, [2], {})
  56. self.assertEqual(ret, 4)
  57. def test_marked_as_started(self):
  58. mytask.track_started = True
  59. try:
  60. jail(gen_unique_id(), mytask.name, [2], {})
  61. finally:
  62. mytask.track_started = False
  63. def test_execute_jail_failure(self):
  64. ret = jail(gen_unique_id(), mytask_raising.name,
  65. [4], {})
  66. self.assertIsInstance(ret, ExceptionInfo)
  67. self.assertTupleEqual(ret.exception.args, (4, ))
  68. def test_execute_ignore_result(self):
  69. task_id = gen_unique_id()
  70. ret = jail(id, MyTaskIgnoreResult.name,
  71. [4], {})
  72. self.assertEqual(ret, 256)
  73. self.assertFalse(AsyncResult(task_id).ready())
  74. class MockEventDispatcher(object):
  75. def __init__(self):
  76. self.sent = []
  77. def send(self, event):
  78. self.sent.append(event)
  79. class test_TaskRequest(unittest.TestCase):
  80. def test_task_wrapper_repr(self):
  81. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  82. self.assertTrue(repr(tw))
  83. def test_send_event(self):
  84. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  85. tw.eventer = MockEventDispatcher()
  86. tw.send_event("task-frobulated")
  87. self.assertIn("task-frobulated", tw.eventer.sent)
  88. def test_send_email(self):
  89. app = app_or_default()
  90. old_mail_admins = app.mail_admins
  91. old_enable_mails = mytask.send_error_emails
  92. mail_sent = [False]
  93. def mock_mail_admins(*args, **kwargs):
  94. mail_sent[0] = True
  95. app.mail_admins = mock_mail_admins
  96. mytask.send_error_emails = True
  97. try:
  98. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  99. try:
  100. raise KeyError("moofoobar")
  101. except:
  102. einfo = ExceptionInfo(sys.exc_info())
  103. tw.on_failure(einfo)
  104. self.assertTrue(mail_sent[0])
  105. mail_sent[0] = False
  106. mytask.send_error_emails = False
  107. tw.on_failure(einfo)
  108. self.assertFalse(mail_sent[0])
  109. finally:
  110. app.mail_admins = old_mail_admins
  111. mytask.send_error_emails = old_enable_mails
  112. def test_already_revoked(self):
  113. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  114. tw._already_revoked = True
  115. self.assertTrue(tw.revoked())
  116. def test_revoked(self):
  117. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  118. revoked.add(tw.task_id)
  119. self.assertTrue(tw.revoked())
  120. self.assertTrue(tw._already_revoked)
  121. self.assertTrue(tw.acknowledged)
  122. def test_execute_does_not_execute_revoked(self):
  123. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  124. revoked.add(tw.task_id)
  125. tw.execute()
  126. def test_execute_acks_late(self):
  127. mytask_raising.acks_late = True
  128. tw = TaskRequest(mytask_raising.name, gen_unique_id(), [1], {"f": "x"})
  129. try:
  130. tw.execute()
  131. self.assertTrue(tw.acknowledged)
  132. finally:
  133. mytask_raising.acks_late = False
  134. def test_execute_using_pool_does_not_execute_revoked(self):
  135. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  136. revoked.add(tw.task_id)
  137. tw.execute_using_pool(None)
  138. def test_on_accepted_acks_early(self):
  139. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  140. tw.on_accepted()
  141. self.assertTrue(tw.acknowledged)
  142. def test_on_accepted_acks_late(self):
  143. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  144. mytask.acks_late = True
  145. try:
  146. tw.on_accepted()
  147. self.assertFalse(tw.acknowledged)
  148. finally:
  149. mytask.acks_late = False
  150. def test_on_success_acks_early(self):
  151. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  152. tw.time_start = 1
  153. tw.on_success(42)
  154. self.assertFalse(tw.acknowledged)
  155. def test_on_success_acks_late(self):
  156. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  157. tw.time_start = 1
  158. mytask.acks_late = True
  159. try:
  160. tw.on_success(42)
  161. self.assertTrue(tw.acknowledged)
  162. finally:
  163. mytask.acks_late = False
  164. def test_on_failure_acks_late(self):
  165. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  166. tw.time_start = 1
  167. mytask.acks_late = True
  168. try:
  169. try:
  170. raise KeyError("foo")
  171. except KeyError:
  172. exc_info = ExceptionInfo(sys.exc_info())
  173. tw.on_failure(exc_info)
  174. self.assertTrue(tw.acknowledged)
  175. finally:
  176. mytask.acks_late = False
  177. def test_from_message_invalid_kwargs(self):
  178. message_data = dict(task="foo", id=1, args=(), kwargs="foo")
  179. self.assertRaises(InvalidTaskError, TaskRequest.from_message, None,
  180. message_data)
  181. def test_on_timeout(self):
  182. class MockLogger(object):
  183. def __init__(self):
  184. self.warnings = []
  185. self.errors = []
  186. def warning(self, msg, *args, **kwargs):
  187. self.warnings.append(msg)
  188. def error(self, msg, *args, **kwargs):
  189. self.errors.append(msg)
  190. tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
  191. tw.logger = MockLogger()
  192. tw.on_timeout(soft=True)
  193. self.assertIn("Soft time limit exceeded", tw.logger.warnings[0])
  194. tw.on_timeout(soft=False)
  195. self.assertIn("Hard time limit exceeded", tw.logger.errors[0])
  196. def test_execute_and_trace(self):
  197. res = execute_and_trace(mytask.name, gen_unique_id(), [4], {})
  198. self.assertEqual(res, 4 ** 4)
  199. def test_execute_safe_catches_exception(self):
  200. old_exec = WorkerTaskTrace.execute
  201. def _error_exec(self, *args, **kwargs):
  202. raise KeyError("baz")
  203. WorkerTaskTrace.execute = _error_exec
  204. try:
  205. def with_catch_warnings(log):
  206. res = execute_and_trace(mytask.name, gen_unique_id(),
  207. [4], {})
  208. self.assertIsInstance(res, ExceptionInfo)
  209. self.assertTrue(log)
  210. self.assertIn("Exception outside", log[0].message.args[0])
  211. self.assertIn("KeyError", log[0].message.args[0])
  212. context = catch_warnings(record=True)
  213. execute_context(context, with_catch_warnings)
  214. finally:
  215. WorkerTaskTrace.execute = old_exec
  216. def create_exception(self, exc):
  217. try:
  218. raise exc
  219. except exc.__class__:
  220. return sys.exc_info()
  221. def test_worker_task_trace_handle_retry(self):
  222. from celery.exceptions import RetryTaskError
  223. uuid = gen_unique_id()
  224. w = WorkerTaskTrace(mytask.name, uuid, [4], {})
  225. type_, value_, tb_ = self.create_exception(ValueError("foo"))
  226. type_, value_, tb_ = self.create_exception(RetryTaskError(str(value_),
  227. exc=value_))
  228. w._store_errors = False
  229. w.handle_retry(value_, type_, tb_, "")
  230. self.assertEqual(mytask.backend.get_status(uuid), states.PENDING)
  231. w._store_errors = True
  232. w.handle_retry(value_, type_, tb_, "")
  233. self.assertEqual(mytask.backend.get_status(uuid), states.RETRY)
  234. def test_worker_task_trace_handle_failure(self):
  235. uuid = gen_unique_id()
  236. w = WorkerTaskTrace(mytask.name, uuid, [4], {})
  237. type_, value_, tb_ = self.create_exception(ValueError("foo"))
  238. w._store_errors = False
  239. w.handle_failure(value_, type_, tb_, "")
  240. self.assertEqual(mytask.backend.get_status(uuid), states.PENDING)
  241. w._store_errors = True
  242. w.handle_failure(value_, type_, tb_, "")
  243. self.assertEqual(mytask.backend.get_status(uuid), states.FAILURE)
  244. def test_executed_bit(self):
  245. tw = TaskRequest(mytask.name, gen_unique_id(), [], {})
  246. self.assertFalse(tw.executed)
  247. tw._set_executed_bit()
  248. self.assertTrue(tw.executed)
  249. self.assertRaises(AlreadyExecutedError, tw._set_executed_bit)
  250. def test_task_wrapper_mail_attrs(self):
  251. tw = TaskRequest(mytask.name, gen_unique_id(), [], {})
  252. x = tw.success_msg % {"name": tw.task_name,
  253. "id": tw.task_id,
  254. "return_value": 10}
  255. self.assertTrue(x)
  256. x = tw.error_msg % {"name": tw.task_name,
  257. "id": tw.task_id,
  258. "exc": "FOOBARBAZ",
  259. "traceback": "foobarbaz"}
  260. self.assertTrue(x)
  261. x = tw.email_subject % {"name": tw.task_name,
  262. "id": tw.task_id,
  263. "exc": "FOOBARBAZ",
  264. "hostname": "lana"}
  265. self.assertTrue(x)
  266. def test_from_message(self):
  267. body = {"task": mytask.name, "id": gen_unique_id(),
  268. "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
  269. m = Message(None, body=simplejson.dumps(body), backend="foo",
  270. content_type="application/json",
  271. content_encoding="utf-8")
  272. tw = TaskRequest.from_message(m, m.decode())
  273. self.assertIsInstance(tw, TaskRequest)
  274. self.assertEqual(tw.task_name, body["task"])
  275. self.assertEqual(tw.task_id, body["id"])
  276. self.assertEqual(tw.args, body["args"])
  277. self.assertEqual(tw.kwargs.keys()[0],
  278. u"æØåveéðƒeæ".encode("utf-8"))
  279. self.assertNotIsInstance(tw.kwargs.keys()[0], unicode)
  280. self.assertTrue(tw.logger)
  281. def test_from_message_nonexistant_task(self):
  282. body = {"task": "cu.mytask.doesnotexist", "id": gen_unique_id(),
  283. "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
  284. m = Message(None, body=simplejson.dumps(body), backend="foo",
  285. content_type="application/json",
  286. content_encoding="utf-8")
  287. self.assertRaises(NotRegistered, TaskRequest.from_message,
  288. m, m.decode())
  289. def test_execute(self):
  290. tid = gen_unique_id()
  291. tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
  292. self.assertEqual(tw.execute(), 256)
  293. meta = mytask.backend.get_task_meta(tid)
  294. self.assertEqual(meta["result"], 256)
  295. self.assertEqual(meta["status"], states.SUCCESS)
  296. def test_execute_success_no_kwargs(self):
  297. tid = gen_unique_id()
  298. tw = TaskRequest(mytask_no_kwargs.name, tid, [4], {})
  299. self.assertEqual(tw.execute(), 256)
  300. meta = mytask_no_kwargs.backend.get_task_meta(tid)
  301. self.assertEqual(meta["result"], 256)
  302. self.assertEqual(meta["status"], states.SUCCESS)
  303. def test_execute_success_some_kwargs(self):
  304. tid = gen_unique_id()
  305. tw = TaskRequest(mytask_some_kwargs.name, tid, [4], {})
  306. self.assertEqual(tw.execute(logfile="foobaz.log"), 256)
  307. meta = mytask_some_kwargs.backend.get_task_meta(tid)
  308. self.assertEqual(some_kwargs_scratchpad.get("logfile"), "foobaz.log")
  309. self.assertEqual(meta["result"], 256)
  310. self.assertEqual(meta["status"], states.SUCCESS)
  311. def test_execute_ack(self):
  312. tid = gen_unique_id()
  313. tw = TaskRequest(mytask.name, tid, [4], {"f": "x"},
  314. on_ack=on_ack)
  315. self.assertEqual(tw.execute(), 256)
  316. meta = mytask.backend.get_task_meta(tid)
  317. self.assertTrue(scratch["ACK"])
  318. self.assertEqual(meta["result"], 256)
  319. self.assertEqual(meta["status"], states.SUCCESS)
  320. def test_execute_fail(self):
  321. tid = gen_unique_id()
  322. tw = TaskRequest(mytask_raising.name, tid, [4], {"f": "x"})
  323. self.assertIsInstance(tw.execute(), ExceptionInfo)
  324. meta = mytask_raising.backend.get_task_meta(tid)
  325. self.assertEqual(meta["status"], states.FAILURE)
  326. self.assertIsInstance(meta["result"], KeyError)
  327. def test_execute_using_pool(self):
  328. tid = gen_unique_id()
  329. tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
  330. class MockPool(object):
  331. target = None
  332. args = None
  333. kwargs = None
  334. def __init__(self, *args, **kwargs):
  335. pass
  336. def apply_async(self, target, args=None, kwargs=None,
  337. *margs, **mkwargs):
  338. self.target = target
  339. self.args = args
  340. self.kwargs = kwargs
  341. p = MockPool()
  342. tw.execute_using_pool(p)
  343. self.assertTrue(p.target)
  344. self.assertEqual(p.args[0], mytask.name)
  345. self.assertEqual(p.args[1], tid)
  346. self.assertEqual(p.args[2], [4])
  347. self.assertIn("f", p.args[3])
  348. self.assertIn([4], p.args)
  349. def test_default_kwargs(self):
  350. tid = gen_unique_id()
  351. tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
  352. self.assertDictEqual(
  353. tw.extend_with_default_kwargs(10, "some_logfile"), {
  354. "f": "x",
  355. "logfile": "some_logfile",
  356. "loglevel": 10,
  357. "task_id": tw.task_id,
  358. "task_retries": 0,
  359. "task_is_eager": False,
  360. "delivery_info": {},
  361. "task_name": tw.task_name})
  362. def _test_on_failure(self, exception):
  363. app = app_or_default()
  364. tid = gen_unique_id()
  365. tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
  366. try:
  367. raise exception
  368. except Exception:
  369. exc_info = ExceptionInfo(sys.exc_info())
  370. logfh = StringIO()
  371. tw.logger.handlers = []
  372. tw.logger = setup_logger(logfile=logfh, loglevel=logging.INFO,
  373. root=False)
  374. app.conf.CELERY_SEND_TASK_ERROR_EMAILS = True
  375. tw.on_failure(exc_info)
  376. logvalue = logfh.getvalue()
  377. self.assertIn(mytask.name, logvalue)
  378. self.assertIn(tid, logvalue)
  379. self.assertIn("ERROR", logvalue)
  380. app.conf.CELERY_SEND_TASK_ERROR_EMAILS = False
  381. def test_on_failure(self):
  382. self._test_on_failure(Exception("Inside unit tests"))
  383. def test_on_failure_unicode_exception(self):
  384. self._test_on_failure(Exception(u"Бобры атакуют"))
  385. def test_on_failure_utf8_exception(self):
  386. self._test_on_failure(Exception(
  387. u"Бобры атакуют".encode('utf8')))