test_worker_job.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # -*- coding: utf-8 -*-
  2. import sys
  3. import unittest
  4. from celery.worker.job import jail
  5. from celery.worker.job import TaskWrapper
  6. from celery.datastructures import ExceptionInfo
  7. from celery.models import TaskMeta
  8. from celery.registry import tasks, NotRegistered
  9. from celery.pool import TaskPool
  10. from celery.utils import gen_unique_id
  11. from carrot.backends.base import BaseMessage
  12. from StringIO import StringIO
  13. from celery.log import setup_logger
  14. from django.core import cache
  15. import simplejson
  16. import logging
  17. scratch = {"ACK": False}
  18. def on_ack():
  19. scratch["ACK"] = True
  20. def mytask(i, **kwargs):
  21. return i ** i
  22. tasks.register(mytask, name="cu.mytask")
  23. def mytask_raising(i, **kwargs):
  24. raise KeyError(i)
  25. tasks.register(mytask_raising, name="cu.mytask-raising")
  26. def get_db_connection(i, **kwargs):
  27. from django.db import connection
  28. return id(connection)
  29. get_db_connection.ignore_result = True
  30. class TestJail(unittest.TestCase):
  31. def test_execute_jail_success(self):
  32. ret = jail(gen_unique_id(), gen_unique_id(), mytask, [2], {})
  33. self.assertEquals(ret, 4)
  34. def test_execute_jail_failure(self):
  35. ret = jail(gen_unique_id(), gen_unique_id(), mytask_raising, [4], {})
  36. self.assertTrue(isinstance(ret, ExceptionInfo))
  37. self.assertEquals(ret.exception.args, (4, ))
  38. def test_django_db_connection_is_closed(self):
  39. from django.db import connection
  40. connection._was_closed = False
  41. old_connection_close = connection.close
  42. def monkeypatched_connection_close(*args, **kwargs):
  43. connection._was_closed = True
  44. return old_connection_close(*args, **kwargs)
  45. connection.close = monkeypatched_connection_close
  46. ret = jail(gen_unique_id(), gen_unique_id(),
  47. get_db_connection, [2], {})
  48. self.assertTrue(connection._was_closed)
  49. connection.close = old_connection_close
  50. def test_django_cache_connection_is_closed(self):
  51. old_cache_close = getattr(cache.cache, "close", None)
  52. old_backend = cache.settings.CACHE_BACKEND
  53. cache.settings.CACHE_BACKEND = "libmemcached"
  54. cache._was_closed = False
  55. old_cache_parse_backend = getattr(cache, "parse_backend_uri", None)
  56. if old_cache_parse_backend: # checks to make sure attr exists
  57. delattr(cache, 'parse_backend_uri')
  58. def monkeypatched_cache_close(*args, **kwargs):
  59. cache._was_closed = True
  60. cache.cache.close = monkeypatched_cache_close
  61. jail(gen_unique_id(), gen_unique_id(), mytask, [4], {})
  62. self.assertTrue(cache._was_closed)
  63. cache.cache.close = old_cache_close
  64. cache.settings.CACHE_BACKEND = old_backend
  65. if old_cache_parse_backend:
  66. cache.parse_backend_uri = old_cache_parse_backend
  67. def test_django_cache_connection_is_closed_django_1_1(self):
  68. old_cache_close = getattr(cache.cache, "close", None)
  69. old_backend = cache.settings.CACHE_BACKEND
  70. cache.settings.CACHE_BACKEND = "libmemcached"
  71. cache._was_closed = False
  72. old_cache_parse_backend = getattr(cache, "parse_backend_uri", None)
  73. cache.parse_backend_uri = lambda uri: ["libmemcached", "1", "2"]
  74. def monkeypatched_cache_close(*args, **kwargs):
  75. cache._was_closed = True
  76. cache.cache.close = monkeypatched_cache_close
  77. jail(gen_unique_id(), gen_unique_id(), mytask, [4], {})
  78. self.assertTrue(cache._was_closed)
  79. cache.cache.close = old_cache_close
  80. cache.settings.CACHE_BACKEND = old_backend
  81. if old_cache_parse_backend:
  82. cache.parse_backend_uri = old_cache_parse_backend
  83. else:
  84. del(cache.parse_backend_uri)
  85. class TestTaskWrapper(unittest.TestCase):
  86. def test_task_wrapper_attrs(self):
  87. tw = TaskWrapper(gen_unique_id(), gen_unique_id(),
  88. mytask, [1], {"f": "x"})
  89. for attr in ("task_name", "task_id", "args", "kwargs", "logger"):
  90. self.assertTrue(getattr(tw, attr, None))
  91. def test_task_wrapper_repr(self):
  92. tw = TaskWrapper(gen_unique_id(), gen_unique_id(),
  93. mytask, [1], {"f": "x"})
  94. self.assertTrue(repr(tw))
  95. def test_task_wrapper_mail_attrs(self):
  96. tw = TaskWrapper(gen_unique_id(), gen_unique_id(), mytask, [], {})
  97. x = tw.success_msg % {"name": tw.task_name,
  98. "id": tw.task_id,
  99. "return_value": 10}
  100. self.assertTrue(x)
  101. x = tw.fail_msg % {"name": tw.task_name,
  102. "id": tw.task_id,
  103. "exc": "FOOBARBAZ",
  104. "traceback": "foobarbaz"}
  105. self.assertTrue(x)
  106. x = tw.fail_email_subject % {"name": tw.task_name,
  107. "id": tw.task_id,
  108. "exc": "FOOBARBAZ",
  109. "hostname": "lana"}
  110. self.assertTrue(x)
  111. def test_from_message(self):
  112. body = {"task": "cu.mytask", "id": gen_unique_id(),
  113. "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
  114. m = BaseMessage(body=simplejson.dumps(body), backend="foo",
  115. content_type="application/json",
  116. content_encoding="utf-8")
  117. tw = TaskWrapper.from_message(m, m.decode())
  118. self.assertTrue(isinstance(tw, TaskWrapper))
  119. self.assertEquals(tw.task_name, body["task"])
  120. self.assertEquals(tw.task_id, body["id"])
  121. self.assertEquals(tw.args, body["args"])
  122. self.assertEquals(tw.kwargs.keys()[0],
  123. u"æØåveéðƒeæ".encode("utf-8"))
  124. self.assertFalse(isinstance(tw.kwargs.keys()[0], unicode))
  125. self.assertEquals(id(mytask), id(tw.task_func))
  126. self.assertTrue(tw.logger)
  127. def test_from_message_nonexistant_task(self):
  128. body = {"task": "cu.mytask.doesnotexist", "id": gen_unique_id(),
  129. "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
  130. m = BaseMessage(body=simplejson.dumps(body), backend="foo",
  131. content_type="application/json",
  132. content_encoding="utf-8")
  133. self.assertRaises(NotRegistered, TaskWrapper.from_message,
  134. m, m.decode())
  135. def test_execute(self):
  136. tid = gen_unique_id()
  137. tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f": "x"})
  138. self.assertEquals(tw.execute(), 256)
  139. meta = TaskMeta.objects.get(task_id=tid)
  140. self.assertEquals(meta.result, 256)
  141. self.assertEquals(meta.status, "DONE")
  142. def test_execute_ack(self):
  143. tid = gen_unique_id()
  144. tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f": "x"},
  145. on_ack=on_ack)
  146. self.assertEquals(tw.execute(), 256)
  147. meta = TaskMeta.objects.get(task_id=tid)
  148. self.assertTrue(scratch["ACK"])
  149. self.assertEquals(meta.result, 256)
  150. self.assertEquals(meta.status, "DONE")
  151. def test_execute_fail(self):
  152. tid = gen_unique_id()
  153. tw = TaskWrapper("cu.mytask-raising", tid, mytask_raising, [4],
  154. {"f": "x"})
  155. self.assertTrue(isinstance(tw.execute(), ExceptionInfo))
  156. meta = TaskMeta.objects.get(task_id=tid)
  157. self.assertEquals(meta.status, "FAILURE")
  158. self.assertTrue(isinstance(meta.result, KeyError))
  159. def test_execute_using_pool(self):
  160. tid = gen_unique_id()
  161. tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f": "x"})
  162. p = TaskPool(2)
  163. p.start()
  164. asyncres = tw.execute_using_pool(p)
  165. self.assertTrue(asyncres.get(), 256)
  166. p.stop()
  167. def test_default_kwargs(self):
  168. tid = gen_unique_id()
  169. tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f": "x"})
  170. self.assertEquals(tw.extend_with_default_kwargs(10, "some_logfile"), {
  171. "f": "x",
  172. "logfile": "some_logfile",
  173. "loglevel": 10,
  174. "task_id": tw.task_id,
  175. "task_name": tw.task_name})
  176. def test_on_failure(self):
  177. tid = gen_unique_id()
  178. tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f": "x"})
  179. try:
  180. raise Exception("Inside unit tests")
  181. except Exception:
  182. exc_info = ExceptionInfo(sys.exc_info())
  183. logfh = StringIO()
  184. tw.logger.handlers = []
  185. tw.logger = setup_logger(logfile=logfh, loglevel=logging.INFO)
  186. from celery import conf
  187. conf.SEND_CELERY_TASK_ERROR_EMAILS = True
  188. tw.on_failure(exc_info, {"task_id": tid, "task_name": "cu.mytask"})
  189. logvalue = logfh.getvalue()
  190. self.assertTrue("cu.mytask" in logvalue)
  191. self.assertTrue(tid in logvalue)
  192. self.assertTrue("ERROR" in logvalue)
  193. conf.SEND_CELERY_TASK_ERROR_EMAILS = False