test_task_sets.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import unittest2 as unittest
  2. import simplejson
  3. from celery import conf
  4. from celery.task import Task
  5. from celery.task.sets import subtask, TaskSet
  6. from celery.tests.utils import execute_context
  7. from celery.tests.compat import catch_warnings
  8. class MockTask(Task):
  9. name = "tasks.add"
  10. def run(self, x, y, **kwargs):
  11. return x + y
  12. @classmethod
  13. def apply_async(cls, args, kwargs, **options):
  14. return (args, kwargs, options)
  15. @classmethod
  16. def apply(cls, args, kwargs, **options):
  17. return (args, kwargs, options)
  18. class test_subtask(unittest.TestCase):
  19. def test_behaves_like_type(self):
  20. s = subtask("tasks.add", (2, 2), {"cache": True},
  21. {"routing_key": "CPU-bound"})
  22. self.assertDictEqual(subtask(s), s)
  23. def test_task_argument_can_be_task_cls(self):
  24. s = subtask(MockTask, (2, 2))
  25. self.assertEqual(s.task, MockTask.name)
  26. def test_apply_async(self):
  27. s = MockTask.subtask((2, 2), {"cache": True},
  28. {"routing_key": "CPU-bound"})
  29. args, kwargs, options = s.apply_async()
  30. self.assertTupleEqual(args, (2, 2))
  31. self.assertDictEqual(kwargs, {"cache": True})
  32. self.assertDictEqual(options, {"routing_key": "CPU-bound"})
  33. def test_delay_argmerge(self):
  34. s = MockTask.subtask((2, ), {"cache": True},
  35. {"routing_key": "CPU-bound"})
  36. args, kwargs, options = s.delay(10, cache=False, other="foo")
  37. self.assertTupleEqual(args, (10, 2))
  38. self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
  39. self.assertDictEqual(options, {"routing_key": "CPU-bound"})
  40. def test_apply_async_argmerge(self):
  41. s = MockTask.subtask((2, ), {"cache": True},
  42. {"routing_key": "CPU-bound"})
  43. args, kwargs, options = s.apply_async((10, ),
  44. {"cache": False, "other": "foo"},
  45. routing_key="IO-bound",
  46. exchange="fast")
  47. self.assertTupleEqual(args, (10, 2))
  48. self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
  49. self.assertDictEqual(options, {"routing_key": "IO-bound",
  50. "exchange": "fast"})
  51. def test_apply_argmerge(self):
  52. s = MockTask.subtask((2, ), {"cache": True},
  53. {"routing_key": "CPU-bound"})
  54. args, kwargs, options = s.apply((10, ),
  55. {"cache": False, "other": "foo"},
  56. routing_key="IO-bound",
  57. exchange="fast")
  58. self.assertTupleEqual(args, (10, 2))
  59. self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
  60. self.assertDictEqual(options, {"routing_key": "IO-bound",
  61. "exchange": "fast"})
  62. def test_is_JSON_serializable(self):
  63. s = MockTask.subtask((2, ), {"cache": True},
  64. {"routing_key": "CPU-bound"})
  65. s.args = list(s.args) # tuples are not preserved
  66. # but this doesn't matter.
  67. self.assertEqual(s,
  68. subtask(simplejson.loads(simplejson.dumps(s))))
  69. class test_TaskSet(unittest.TestCase):
  70. def test_interface__compat(self):
  71. def with_catch_warnings(log):
  72. ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
  73. self.assertTrue(log)
  74. self.assertIn("Using this invocation of TaskSet is deprecated",
  75. log[0].message.args[0])
  76. self.assertListEqual(ts.tasks,
  77. [MockTask.subtask((i, i))
  78. for i in (2, 4, 8)])
  79. return ts
  80. context = catch_warnings(record=True)
  81. execute_context(context, with_catch_warnings)
  82. # TaskSet.task (deprecated)
  83. def with_catch_warnings2(log):
  84. ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
  85. self.assertEqual(ts.task, MockTask)
  86. self.assertTrue(log)
  87. self.assertIn("TaskSet.task is deprecated",
  88. log[0].message.args[0])
  89. execute_context(catch_warnings(record=True), with_catch_warnings2)
  90. # TaskSet.task_name (deprecated)
  91. def with_catch_warnings3(log):
  92. ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
  93. self.assertEqual(ts.task_name, MockTask.name)
  94. self.assertTrue(log)
  95. self.assertIn("TaskSet.task_name is deprecated",
  96. log[0].message.args[0])
  97. execute_context(catch_warnings(record=True), with_catch_warnings3)
  98. def test_task_arg_can_be_iterable__compat(self):
  99. ts = TaskSet([MockTask.subtask((i, i))
  100. for i in (2, 4, 8)])
  101. self.assertEqual(len(ts), 3)
  102. def test_respects_ALWAYS_EAGER(self):
  103. class MockTaskSet(TaskSet):
  104. applied = 0
  105. def apply(self, *args, **kwargs):
  106. self.applied += 1
  107. ts = MockTaskSet([MockTask.subtask((i, i))
  108. for i in (2, 4, 8)])
  109. conf.ALWAYS_EAGER = True
  110. try:
  111. ts.apply_async()
  112. finally:
  113. conf.ALWAYS_EAGER = False
  114. self.assertEqual(ts.applied, 1)
  115. def test_apply_async(self):
  116. applied = [0]
  117. class mocksubtask(subtask):
  118. def apply_async(self, *args, **kwargs):
  119. applied[0] += 1
  120. ts = TaskSet([mocksubtask(MockTask, (i, i))
  121. for i in (2, 4, 8)])
  122. ts.apply_async()
  123. self.assertEqual(applied[0], 3)
  124. def test_apply(self):
  125. applied = [0]
  126. class mocksubtask(subtask):
  127. def apply(self, *args, **kwargs):
  128. applied[0] += 1
  129. ts = TaskSet([mocksubtask(MockTask, (i, i))
  130. for i in (2, 4, 8)])
  131. ts.apply()
  132. self.assertEqual(applied[0], 3)