test_worker_job.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # -*- coding: utf-8 -*-
  2. import sys
  3. import unittest
  4. from celery.worker.job import WorkerTaskTrace, TaskWrapper
  5. from celery.datastructures import ExceptionInfo
  6. from celery.models import TaskMeta
  7. from celery.registry import tasks, NotRegistered
  8. from celery.pool import TaskPool
  9. from celery.utils import gen_unique_id
  10. from carrot.backends.base import BaseMessage
  11. from StringIO import StringIO
  12. from celery.log import setup_logger
  13. from django.core import cache
  14. from celery.decorators import task as task_dec
  15. import simplejson
  16. import logging
  # Shared module-level state that callbacks and task fixtures write into,
  # letting the tests observe their side effects.
  scratch = dict(ACK=False)
  some_kwargs_scratchpad = dict()


  def jail(task_id, task_name, args, kwargs):
      """Run the registered task *task_name* synchronously via WorkerTaskTrace.

      Returns whatever the trace call returns: the task's return value, or an
      ExceptionInfo when the task raised (as the TestJail cases assert).
      """
      trace = WorkerTaskTrace(task_name, task_id, args, kwargs)
      return trace()


  def on_ack():
      """Acknowledgement callback: record in ``scratch`` that it was called."""
      scratch.update(ACK=True)
  @task_dec()
  def mytask(i, **kwargs):
      """Task fixture: return ``i`` raised to the power ``i``; extra kwargs are ignored."""
      return pow(i, i)
  @task_dec()
  def mytask_no_kwargs(i):
      """Task fixture with no ``**kwargs``: return ``i`` raised to the power ``i``.

      Because it declares no keyword arguments, the worker presumably must not
      pass its default kwargs here (see test_execute_success_no_kwargs).
      """
      return pow(i, i)
  @task_dec()
  def mytask_some_kwargs(i, logfile):
      """Task fixture accepting only the ``logfile`` keyword.

      Stores the received ``logfile`` in ``some_kwargs_scratchpad`` so a test can
      check what the worker passed, then returns ``i`` raised to the power ``i``.
      """
      some_kwargs_scratchpad.update(logfile=logfile)
      return pow(i, i)
  @task_dec()
  def mytask_raising(i, **kwargs):
      """Task fixture that always fails by raising ``KeyError(i)``."""
      raise KeyError(i)
  @task_dec()
  def get_db_connection(i, **kwargs):
      """Task fixture: return the ``id()`` of Django's current DB connection object."""
      # Imported lazily so the connection object is looked up at execution time.
      from django.db import connection as db_connection
      return id(db_connection)
  # NOTE(review): presumably tells the worker not to store this task's result.
  get_db_connection.ignore_result = True
  class TestJail(unittest.TestCase):
      """Tests for running tasks through WorkerTaskTrace via ``jail``.

      Every monkeypatch applied here is undone in a ``finally`` clause, so a
      failing assertion cannot leak patched Django state into later tests.
      """

      def test_execute_jail_success(self):
          ret = jail(gen_unique_id(), mytask.name, [2], {})
          self.assertEqual(ret, 4)

      def test_execute_jail_failure(self):
          ret = jail(gen_unique_id(), mytask_raising.name, [4], {})
          # A raising task is not propagated; the trace returns ExceptionInfo.
          self.assertTrue(isinstance(ret, ExceptionInfo))
          self.assertEqual(ret.exception.args, (4, ))

      def test_django_db_connection_is_closed(self):
          from django.db import connection
          connection._was_closed = False
          old_connection_close = connection.close

          def monkeypatched_connection_close(*args, **kwargs):
              connection._was_closed = True
              return old_connection_close(*args, **kwargs)

          connection.close = monkeypatched_connection_close
          try:
              jail(gen_unique_id(), get_db_connection.name, [2], {})
              self.assertTrue(connection._was_closed)
          finally:
              # Restore even when the assertion fails.
              connection.close = old_connection_close

      def _assert_cache_closed_after_task(self):
          """Run ``mytask`` through ``jail`` and assert ``cache.cache.close`` was called.

          Temporarily sets ``CACHE_BACKEND`` to "libmemcached" and replaces
          ``cache.cache.close`` with a recorder. Both are restored afterwards,
          even when the assertion fails.
          """
          # Remember whether ``close`` lived on the instance itself. If it did
          # not, delete our patch afterwards instead of leaving ``close = None``.
          had_instance_close = "close" in cache.cache.__dict__
          old_cache_close = getattr(cache.cache, "close", None)
          old_backend = cache.settings.CACHE_BACKEND
          cache.settings.CACHE_BACKEND = "libmemcached"
          cache._was_closed = False

          def monkeypatched_cache_close(*args, **kwargs):
              cache._was_closed = True

          cache.cache.close = monkeypatched_cache_close
          try:
              jail(gen_unique_id(), mytask.name, [4], {})
              self.assertTrue(cache._was_closed)
          finally:
              cache.settings.CACHE_BACKEND = old_backend
              if had_instance_close:
                  cache.cache.close = old_cache_close
              else:
                  del cache.cache.close

      def test_django_cache_connection_is_closed(self):
          # Exercise the code path taken when django.core.cache has no
          # parse_backend_uri; the *_django_1_1 test covers the other path.
          old_cache_parse_backend = getattr(cache, "parse_backend_uri", None)
          if old_cache_parse_backend:
              delattr(cache, "parse_backend_uri")
          try:
              self._assert_cache_closed_after_task()
          finally:
              if old_cache_parse_backend:
                  cache.parse_backend_uri = old_cache_parse_backend

      def test_django_cache_connection_is_closed_django_1_1(self):
          old_cache_parse_backend = getattr(cache, "parse_backend_uri", None)
          # Stub parse_backend_uri so it reports a libmemcached backend.
          cache.parse_backend_uri = lambda uri: ["libmemcached", "1", "2"]
          try:
              self._assert_cache_closed_after_task()
          finally:
              if old_cache_parse_backend:
                  cache.parse_backend_uri = old_cache_parse_backend
              else:
                  del cache.parse_backend_uri
  class TestTaskWrapper(unittest.TestCase):
      """Tests for TaskWrapper: construction, message decoding, execution and failure handling."""

      def test_task_wrapper_repr(self):
          tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"})
          self.assertTrue(repr(tw))

      def test_task_wrapper_mail_attrs(self):
          # The message templates must accept these %-format keys.
          tw = TaskWrapper(mytask.name, gen_unique_id(), [], {})
          x = tw.success_msg % {"name": tw.task_name,
                                "id": tw.task_id,
                                "return_value": 10}
          self.assertTrue(x)
          x = tw.fail_msg % {"name": tw.task_name,
                             "id": tw.task_id,
                             "exc": "FOOBARBAZ",
                             "traceback": "foobarbaz"}
          self.assertTrue(x)
          x = tw.fail_email_subject % {"name": tw.task_name,
                                       "id": tw.task_id,
                                       "exc": "FOOBARBAZ",
                                       "hostname": "lana"}
          self.assertTrue(x)

      def test_from_message(self):
          body = {"task": mytask.name, "id": gen_unique_id(),
                  "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
          m = BaseMessage(body=simplejson.dumps(body), backend="foo",
                          content_type="application/json",
                          content_encoding="utf-8")
          tw = TaskWrapper.from_message(m, m.decode())
          self.assertTrue(isinstance(tw, TaskWrapper))
          self.assertEqual(tw.task_name, body["task"])
          self.assertEqual(tw.task_id, body["id"])
          self.assertEqual(tw.args, body["args"])
          # Decoded keyword names must be UTF-8 encoded byte strings, not unicode.
          self.assertEqual(tw.kwargs.keys()[0],
                           u"æØåveéðƒeæ".encode("utf-8"))
          self.assertFalse(isinstance(tw.kwargs.keys()[0], unicode))
          self.assertTrue(tw.logger)

      def test_from_message_nonexistant_task(self):
          body = {"task": "cu.mytask.doesnotexist", "id": gen_unique_id(),
                  "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
          m = BaseMessage(body=simplejson.dumps(body), backend="foo",
                          content_type="application/json",
                          content_encoding="utf-8")
          self.assertRaises(NotRegistered, TaskWrapper.from_message,
                            m, m.decode())

      def test_execute(self):
          tid = gen_unique_id()
          tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"})
          self.assertEqual(tw.execute(), 256)
          meta = TaskMeta.objects.get(task_id=tid)
          self.assertEqual(meta.result, 256)
          self.assertEqual(meta.status, "DONE")

      def test_execute_success_no_kwargs(self):
          tid = gen_unique_id()
          tw = TaskWrapper(mytask_no_kwargs.name, tid, [4], {})
          self.assertEqual(tw.execute(), 256)
          meta = TaskMeta.objects.get(task_id=tid)
          self.assertEqual(meta.result, 256)
          self.assertEqual(meta.status, "DONE")

      def test_execute_success_some_kwargs(self):
          # Clear shared state so a value left by an earlier run cannot
          # satisfy the logfile assertion.
          some_kwargs_scratchpad.clear()
          tid = gen_unique_id()
          tw = TaskWrapper(mytask_some_kwargs.name, tid, [4], {})
          self.assertEqual(tw.execute(logfile="foobaz.log"), 256)
          meta = TaskMeta.objects.get(task_id=tid)
          self.assertEqual(some_kwargs_scratchpad.get("logfile"), "foobaz.log")
          self.assertEqual(meta.result, 256)
          self.assertEqual(meta.status, "DONE")

      def test_execute_ack(self):
          # Reset so the assertion proves that *this* execution acknowledged.
          scratch["ACK"] = False
          tid = gen_unique_id()
          tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"},
                           on_ack=on_ack)
          self.assertEqual(tw.execute(), 256)
          meta = TaskMeta.objects.get(task_id=tid)
          self.assertTrue(scratch["ACK"])
          self.assertEqual(meta.result, 256)
          self.assertEqual(meta.status, "DONE")

      def test_execute_fail(self):
          tid = gen_unique_id()
          tw = TaskWrapper(mytask_raising.name, tid, [4], {"f": "x"})
          self.assertTrue(isinstance(tw.execute(), ExceptionInfo))
          meta = TaskMeta.objects.get(task_id=tid)
          self.assertEqual(meta.status, "FAILURE")
          self.assertTrue(isinstance(meta.result, KeyError))

      def test_execute_using_pool(self):
          tid = gen_unique_id()
          tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"})
          p = TaskPool(2)
          p.start()
          try:
              asyncres = tw.execute_using_pool(p)
              # assertTrue(x, 256) treated 256 as the failure message and
              # passed for any truthy result; compare the actual value.
              self.assertEqual(asyncres.get(), 256)
          finally:
              # Always stop the pool, even when the assertion fails.
              p.stop()

      def test_default_kwargs(self):
          tid = gen_unique_id()
          tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"})
          self.assertEqual(tw.extend_with_default_kwargs(10, "some_logfile"), {
              "f": "x",
              "logfile": "some_logfile",
              "loglevel": 10,
              "task_id": tw.task_id,
              "task_retries": 0,
              "task_name": tw.task_name})

      def test_on_failure(self):
          tid = gen_unique_id()
          tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"})
          try:
              raise Exception("Inside unit tests")
          except Exception:
              exc_info = ExceptionInfo(sys.exc_info())
          # Capture log output in memory so it can be inspected.
          logfh = StringIO()
          tw.logger.handlers = []
          tw.logger = setup_logger(logfile=logfh, loglevel=logging.INFO)
          from celery import conf
          # Restore the previous setting (not a hard-coded False), even on error.
          old_send_emails = getattr(conf, "SEND_CELERY_TASK_ERROR_EMAILS", False)
          conf.SEND_CELERY_TASK_ERROR_EMAILS = True
          try:
              tw.on_failure(exc_info)
          finally:
              conf.SEND_CELERY_TASK_ERROR_EMAILS = old_send_emails
          logvalue = logfh.getvalue()
          self.assertTrue(mytask.name in logvalue)
          self.assertTrue(tid in logvalue)
          self.assertTrue("ERROR" in logvalue)