test_task.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import unittest
  2. import uuid
  3. import logging
  4. from StringIO import StringIO
  5. from celery import task
  6. from celery import registry
  7. from celery.log import setup_logger
  8. from celery import messaging
  9. from celery.backends import default_backend
  def return_True(self, **kwargs):
      """Stand-in ``run`` implementation that always succeeds.

      Written with a ``self`` parameter so it can be assigned to a class as
      its ``run`` method. Any keyword arguments are accepted and ignored.

      :returns: ``True``, unconditionally.
      """
      # Task run functions can't be closures/lambdas, as they're pickled.
      return True
  def raise_exception(self, **kwargs):
      """Stand-in ``run`` implementation that always fails.

      Written with a ``self`` parameter so it can be assigned to a class as
      its ``run`` method. Any keyword arguments are accepted and ignored.

      :raises Exception: always; the message is the class of ``self``
          followed by ``" error"``.
      """
      raise Exception("%s error" % self.__class__)
  class IncrementCounterTask(task.Task):
      """Task that adds to a class-level counter every time it runs."""

      name = "c.unittest.increment_counter_task"
      # Running total, stored on the class so it is shared by all instances.
      count = 0

      def run(self, increment_by=None, **kwargs):
          """Add ``increment_by`` to the shared counter.

          :param increment_by: amount to add. Optional: when omitted, or
              when any falsy value (``None``, ``0``) is passed, the counter
              is increased by 1 instead.
          """
          increment_by = increment_by or 1
          self.__class__.count += increment_by
  class TestCeleryTasks(unittest.TestCase):
      """Tests for defining, delaying, discarding and completing tasks."""

      def createTaskCls(self, cls_name, task_name=None):
          """Build a ``task.Task`` subclass whose ``run`` is ``return_True``.

          :param cls_name: name given to the generated class.
          :param task_name: value for the class ``name`` attribute; the
              attribute is not set at all when this is falsy.
          :returns: the newly created class.
          """
          attrs = {}
          if task_name:
              attrs["name"] = task_name
          cls = type(cls_name, (task.Task, ), attrs)
          cls.run = return_True
          return cls

      def assertNextTaskDataEquals(self, consumer, presult, task_name,
              **kwargs):
          """Fetch the next message and check that it describes a given task.

          :param consumer: object providing ``fetch()`` and ``decoder()``.
          :param presult: result object whose ``task_id`` must equal the
              decoded message's ``"id"``.
          :param task_name: expected value of the message's ``"task"`` field.
          :param kwargs: every item given here must match the corresponding
              entry in the message's ``"kwargs"`` mapping.
          """
          next_task = consumer.fetch()
          task_data = consumer.decoder(next_task.body)
          self.assertEqual(task_data["id"], presult.task_id)
          self.assertEqual(task_data["task"], task_name)
          task_kwargs = task_data.get("kwargs", {})
          for arg_name, arg_value in kwargs.items():
              self.assertEqual(task_kwargs.get(arg_name), arg_value)

      def test_incomplete_task_cls(self):
          """A task class that does not define ``run`` raises when run."""

          class IncompleteTask(task.Task):
              name = "c.unittest.t.itask"

          self.assertRaises(NotImplementedError, IncompleteTask().run)

      def test_regular_task(self):
          """Walk a named task through calling, delaying and completion."""
          T1 = self.createTaskCls("T1", "c.unittest.t.t1")
          self.assertTrue(isinstance(T1(), T1))
          self.assertTrue(T1().run())
          self.assertTrue(callable(T1()),
                  "Task class is callable()")
          self.assertTrue(T1()(),
                  "Task class runs run() when called")

          # task without name raises NotImplementedError
          T2 = self.createTaskCls("T2")
          self.assertRaises(NotImplementedError, T2)

          registry.tasks.register(T1)
          t1 = T1()
          consumer = t1.get_consumer()
          self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
          # Empty the queue so the fetches below only see messages sent here.
          consumer.discard_all()
          self.assertTrue(consumer.fetch() is None)

          # Without arguments.
          presult = t1.delay()
          self.assertNextTaskDataEquals(consumer, presult, t1.name)

          # With arguments.
          presult2 = task.delay_task(t1.name, name="George Constanza")
          self.assertNextTaskDataEquals(consumer, presult2, t1.name,
                  name="George Constanza")

          self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
                  "some.task.that.should.never.exist.X.X.X.X.X")

          # Discarding all tasks.
          task.discard_all()
          # Publish exactly one message so the discard count below is known.
          task.delay_task(t1.name)
          self.assertEqual(task.discard_all(), 1)
          self.assertTrue(consumer.fetch() is None)

          # The task is not reported as done until the backend is told so.
          self.assertFalse(task.is_done(presult.task_id))
          self.assertFalse(presult.is_done())
          default_backend.mark_as_done(presult.task_id, result=None)
          self.assertTrue(task.is_done(presult.task_id))
          self.assertTrue(presult.is_done())

          publisher = t1.get_publisher()
          self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
  class TestTaskSet(unittest.TestCase):
      """Tests for publishing a group of tasks through ``task.TaskSet``."""

      def test_counter_taskset(self):
          """Publish nine counter tasks, then check and run each message."""
          # ``count`` lives on the class and is shared by the whole process;
          # start from zero so the final total does not depend on test order.
          IncrementCounterTask.count = 0

          ts = task.TaskSet(IncrementCounterTask, [
              [[], {}],
              [[], {"increment_by": 2}],
              [[], {"increment_by": 3}],
              [[], {"increment_by": 4}],
              [[], {"increment_by": 5}],
              [[], {"increment_by": 6}],
              [[], {"increment_by": 7}],
              [[], {"increment_by": 8}],
              [[], {"increment_by": 9}],
          ])
          self.assertEqual(ts.task_name, IncrementCounterTask.name)
          self.assertEqual(ts.total, 9)

          taskset_id, subtask_ids = ts.run()

          consumer = IncrementCounterTask().get_consumer()
          for subtask_id in subtask_ids:
              m = consumer.decoder(consumer.fetch().body)
              self.assertEqual(m.get("taskset"), taskset_id)
              self.assertEqual(m.get("task"), IncrementCounterTask.name)
              self.assertEqual(m.get("id"), subtask_id)
              # Run the task by hand with whatever "increment_by" the message
              # carried; a missing value is passed on as None.
              IncrementCounterTask().run(
                      increment_by=m.get("kwargs", {}).get("increment_by"))
          self.assertEqual(IncrementCounterTask.count, sum(range(1, 10)))