test_task.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. import unittest
  2. import uuid
  3. import logging
  4. from StringIO import StringIO
  5. from celery import task
  6. from celery import registry
  7. from celery.log import setup_logger
  8. from celery import messaging
  9. from celery.result import EagerResult
  10. from celery.backends import default_backend
  11. from datetime import datetime, timedelta
  12. def return_True(self, **kwargs):
  13. # Task run functions can't be closures/lambdas, as they're pickled.
  14. return True
  15. registry.tasks.register(return_True, "cu.return-true")
  16. def raise_exception(self, **kwargs):
  17. raise Exception("%s error" % self.__class__)
  18. class IncrementCounterTask(task.Task):
  19. name = "c.unittest.increment_counter_task"
  20. count = 0
  21. def run(self, increment_by=1, **kwargs):
  22. increment_by = increment_by or 1
  23. self.__class__.count += increment_by
  24. return self.__class__.count
  25. class RaisingTask(task.Task):
  26. name = "c.unittest.raising_task"
  27. def run(self, **kwargs):
  28. raise KeyError("foo")
  29. class RetryTask(task.Task):
  30. max_retries = 3
  31. iterations = 0
  32. def run(self, arg1, arg2, kwarg=1, **kwargs):
  33. self.__class__.iterations += 1
  34. retries = kwargs["task_retries"]
  35. if retries >= 3:
  36. return arg1
  37. else:
  38. kwargs.update({"kwarg": kwarg})
  39. return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
  40. class MyCustomException(Exception):
  41. """Random custom exception."""
  42. class RetryTaskCustomExc(task.Task):
  43. max_retries = 3
  44. iterations = 0
  45. def run(self, arg1, arg2, kwarg=1, **kwargs):
  46. self.__class__.iterations += 1
  47. retries = kwargs["task_retries"]
  48. if retries >= 3:
  49. return arg1 + kwarg
  50. else:
  51. try:
  52. raise MyCustomException("Elaine Marie Benes")
  53. except MyCustomException, exc:
  54. kwargs.update({"kwarg": kwarg})
  55. return self.retry(args=[arg1, arg2], kwargs=kwargs,
  56. countdown=0, exc=exc)
  57. class TestTaskRetries(unittest.TestCase):
  58. def test_retry(self):
  59. RetryTask.max_retries = 3
  60. RetryTask.iterations = 0
  61. result = RetryTask.apply([0xFF, 0xFFFF])
  62. self.assertEquals(result.get(), 0xFF)
  63. self.assertEquals(RetryTask.iterations, 4)
  64. def test_retry_with_kwargs(self):
  65. RetryTaskCustomExc.max_retries = 3
  66. RetryTaskCustomExc.iterations = 0
  67. result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
  68. self.assertEquals(result.get(), 0xFF + 0xF)
  69. self.assertEquals(RetryTaskCustomExc.iterations, 4)
  70. def test_retry_with_custom_exception(self):
  71. RetryTaskCustomExc.max_retries = 2
  72. RetryTaskCustomExc.iterations = 0
  73. result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
  74. self.assertRaises(MyCustomException,
  75. result.get)
  76. self.assertEquals(RetryTaskCustomExc.iterations, 3)
  77. def test_max_retries_exceeded(self):
  78. RetryTask.max_retries = 2
  79. RetryTask.iterations = 0
  80. result = RetryTask.apply([0xFF, 0xFFFF])
  81. self.assertRaises(RetryTask.MaxRetriesExceededError,
  82. result.get)
  83. self.assertEquals(RetryTask.iterations, 3)
  84. RetryTask.max_retries = 1
  85. RetryTask.iterations = 0
  86. result = RetryTask.apply([0xFF, 0xFFFF])
  87. self.assertRaises(RetryTask.MaxRetriesExceededError,
  88. result.get)
  89. self.assertEquals(RetryTask.iterations, 2)
  90. class TestCeleryTasks(unittest.TestCase):
  91. def createTaskCls(self, cls_name, task_name=None):
  92. attrs = {}
  93. if task_name:
  94. attrs["name"] = task_name
  95. cls = type(cls_name, (task.Task, ), attrs)
  96. cls.run = return_True
  97. return cls
  98. def test_ping(self):
  99. from celery import conf
  100. conf.ALWAYS_EAGER = True
  101. self.assertEquals(task.ping(), 'pong')
  102. conf.ALWAYS_EAGER = False
  103. def test_execute_remote(self):
  104. from celery import conf
  105. conf.ALWAYS_EAGER = True
  106. self.assertEquals(task.execute_remote(return_True, ["foo"]).get(),
  107. True)
  108. conf.ALWAYS_EAGER = False
  109. def test_dmap(self):
  110. from celery import conf
  111. import operator
  112. conf.ALWAYS_EAGER = True
  113. res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
  114. self.assertTrue(res, sum([operator.add(x, x)
  115. for x in xrange(10)]))
  116. conf.ALWAYS_EAGER = False
  117. def test_dmap_async(self):
  118. from celery import conf
  119. import operator
  120. conf.ALWAYS_EAGER = True
  121. res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
  122. self.assertTrue(res.get(), sum([operator.add(x, x)
  123. for x in xrange(10)]))
  124. conf.ALWAYS_EAGER = False
  125. def assertNextTaskDataEquals(self, consumer, presult, task_name,
  126. test_eta=False, **kwargs):
  127. next_task = consumer.fetch()
  128. task_data = next_task.decode()
  129. self.assertEquals(task_data["id"], presult.task_id)
  130. self.assertEquals(task_data["task"], task_name)
  131. task_kwargs = task_data.get("kwargs", {})
  132. if test_eta:
  133. self.assertTrue(isinstance(task_data.get("eta"), datetime))
  134. for arg_name, arg_value in kwargs.items():
  135. self.assertEquals(task_kwargs.get(arg_name), arg_value)
  136. def test_incomplete_task_cls(self):
  137. class IncompleteTask(task.Task):
  138. name = "c.unittest.t.itask"
  139. self.assertRaises(NotImplementedError, IncompleteTask().run)
  140. def test_regular_task(self):
  141. T1 = self.createTaskCls("T1", "c.unittest.t.t1")
  142. self.assertTrue(isinstance(T1(), T1))
  143. self.assertTrue(T1().run())
  144. self.assertTrue(callable(T1()),
  145. "Task class is callable()")
  146. self.assertTrue(T1()(),
  147. "Task class runs run() when called")
  148. # task name generated out of class module + name.
  149. T2 = self.createTaskCls("T2")
  150. self.assertEquals(T2().name, "celery.tests.test_task.T2")
  151. registry.tasks.register(T1)
  152. t1 = T1()
  153. consumer = t1.get_consumer()
  154. self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
  155. consumer.discard_all()
  156. self.assertTrue(consumer.fetch() is None)
  157. # Without arguments.
  158. presult = t1.delay()
  159. self.assertNextTaskDataEquals(consumer, presult, t1.name)
  160. # With arguments.
  161. presult2 = task.delay_task(t1.name, name="George Constanza")
  162. self.assertNextTaskDataEquals(consumer, presult2, t1.name,
  163. name="George Constanza")
  164. # With eta.
  165. presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
  166. eta=datetime.now() + timedelta(days=1))
  167. self.assertNextTaskDataEquals(consumer, presult2, t1.name,
  168. name="George Constanza", test_eta=True)
  169. # With countdown.
  170. presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
  171. countdown=10)
  172. self.assertNextTaskDataEquals(consumer, presult2, t1.name,
  173. name="George Constanza", test_eta=True)
  174. self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
  175. "some.task.that.should.never.exist.X.X.X.X.X")
  176. # Discarding all tasks.
  177. task.discard_all()
  178. tid3 = task.delay_task(t1.name)
  179. self.assertEquals(task.discard_all(), 1)
  180. self.assertTrue(consumer.fetch() is None)
  181. self.assertFalse(task.is_done(presult.task_id))
  182. self.assertFalse(presult.is_done())
  183. default_backend.mark_as_done(presult.task_id, result=None)
  184. self.assertTrue(task.is_done(presult.task_id))
  185. self.assertTrue(presult.is_done())
  186. publisher = t1.get_publisher()
  187. self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
  188. def test_get_logger(self):
  189. T1 = self.createTaskCls("T1", "c.unittest.t.t1")
  190. t1 = T1()
  191. logfh = StringIO()
  192. logger = t1.get_logger(logfile=logfh, loglevel=0)
  193. self.assertTrue(logger)
  194. class TestTaskSet(unittest.TestCase):
  195. def test_function_taskset(self):
  196. from celery import conf
  197. conf.ALWAYS_EAGER = True
  198. ts = task.TaskSet("cu.return-true", [
  199. [[1], {}], [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
  200. res = ts.run()
  201. self.assertEquals(res.join(), [True, True, True, True, True])
  202. conf.ALWAYS_EAGER = False
  203. def test_counter_taskset(self):
  204. IncrementCounterTask.count = 0
  205. ts = task.TaskSet(IncrementCounterTask, [
  206. [[], {}],
  207. [[], {"increment_by": 2}],
  208. [[], {"increment_by": 3}],
  209. [[], {"increment_by": 4}],
  210. [[], {"increment_by": 5}],
  211. [[], {"increment_by": 6}],
  212. [[], {"increment_by": 7}],
  213. [[], {"increment_by": 8}],
  214. [[], {"increment_by": 9}],
  215. ])
  216. self.assertEquals(ts.task_name, IncrementCounterTask.name)
  217. self.assertEquals(ts.total, 9)
  218. taskset_res = ts.run()
  219. subtasks = taskset_res.subtasks
  220. taskset_id = taskset_res.taskset_id
  221. consumer = IncrementCounterTask().get_consumer()
  222. for subtask in subtasks:
  223. m = consumer.decoder(consumer.fetch().body)
  224. self.assertEquals(m.get("taskset"), taskset_id)
  225. self.assertEquals(m.get("task"), IncrementCounterTask.name)
  226. self.assertEquals(m.get("id"), subtask.task_id)
  227. IncrementCounterTask().run(
  228. increment_by=m.get("kwargs", {}).get("increment_by"))
  229. self.assertEquals(IncrementCounterTask.count, sum(xrange(1, 10)))
  230. class TestTaskApply(unittest.TestCase):
  231. def test_apply(self):
  232. IncrementCounterTask.count = 0
  233. e = IncrementCounterTask.apply()
  234. self.assertTrue(isinstance(e, EagerResult))
  235. self.assertEquals(e.get(), 1)
  236. e = IncrementCounterTask.apply(args=[1])
  237. self.assertEquals(e.get(), 2)
  238. e = IncrementCounterTask.apply(kwargs={"increment_by": 4})
  239. self.assertEquals(e.get(), 6)
  240. self.assertTrue(e.is_done())
  241. self.assertTrue(e.is_ready())
  242. self.assertTrue(repr(e).startswith("<EagerResult:"))
  243. f = RaisingTask.apply()
  244. self.assertTrue(f.is_ready())
  245. self.assertFalse(f.is_done())
  246. self.assertTrue(f.traceback)
  247. self.assertRaises(KeyError, f.get)
  248. class TestPeriodicTask(unittest.TestCase):
  249. def test_interface(self):
  250. class MyPeriodicTask(task.PeriodicTask):
  251. run_every = None
  252. self.assertRaises(NotImplementedError, MyPeriodicTask)