test_task_sets.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import anyjson
  2. import warnings
  3. from celery.app import app_or_default
  4. from celery.task import Task
  5. from celery.task.sets import subtask, TaskSet
  6. from celery.tests.utils import unittest
  7. from celery.tests.utils import execute_context
  8. from celery.tests.compat import catch_warnings
  9. class MockTask(Task):
  10. name = "tasks.add"
  11. def run(self, x, y, **kwargs):
  12. return x + y
  13. @classmethod
  14. def apply_async(cls, args, kwargs, **options):
  15. return (args, kwargs, options)
  16. @classmethod
  17. def apply(cls, args, kwargs, **options):
  18. return (args, kwargs, options)
  19. class test_subtask(unittest.TestCase):
  20. def test_behaves_like_type(self):
  21. s = subtask("tasks.add", (2, 2), {"cache": True},
  22. {"routing_key": "CPU-bound"})
  23. self.assertDictEqual(subtask(s), s)
  24. def test_task_argument_can_be_task_cls(self):
  25. s = subtask(MockTask, (2, 2))
  26. self.assertEqual(s.task, MockTask.name)
  27. def test_apply_async(self):
  28. s = MockTask.subtask((2, 2), {"cache": True},
  29. {"routing_key": "CPU-bound"})
  30. args, kwargs, options = s.apply_async()
  31. self.assertTupleEqual(args, (2, 2))
  32. self.assertDictEqual(kwargs, {"cache": True})
  33. self.assertDictEqual(options, {"routing_key": "CPU-bound"})
  34. def test_delay_argmerge(self):
  35. s = MockTask.subtask((2, ), {"cache": True},
  36. {"routing_key": "CPU-bound"})
  37. args, kwargs, options = s.delay(10, cache=False, other="foo")
  38. self.assertTupleEqual(args, (10, 2))
  39. self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
  40. self.assertDictEqual(options, {"routing_key": "CPU-bound"})
  41. def test_apply_async_argmerge(self):
  42. s = MockTask.subtask((2, ), {"cache": True},
  43. {"routing_key": "CPU-bound"})
  44. args, kwargs, options = s.apply_async((10, ),
  45. {"cache": False, "other": "foo"},
  46. routing_key="IO-bound",
  47. exchange="fast")
  48. self.assertTupleEqual(args, (10, 2))
  49. self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
  50. self.assertDictEqual(options, {"routing_key": "IO-bound",
  51. "exchange": "fast"})
  52. def test_apply_argmerge(self):
  53. s = MockTask.subtask((2, ), {"cache": True},
  54. {"routing_key": "CPU-bound"})
  55. args, kwargs, options = s.apply((10, ),
  56. {"cache": False, "other": "foo"},
  57. routing_key="IO-bound",
  58. exchange="fast")
  59. self.assertTupleEqual(args, (10, 2))
  60. self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
  61. self.assertDictEqual(options, {"routing_key": "IO-bound",
  62. "exchange": "fast"})
  63. def test_is_JSON_serializable(self):
  64. s = MockTask.subtask((2, ), {"cache": True},
  65. {"routing_key": "CPU-bound"})
  66. s.args = list(s.args) # tuples are not preserved
  67. # but this doesn't matter.
  68. self.assertEqual(s,
  69. subtask(anyjson.deserialize(
  70. anyjson.serialize(s))))
  71. class test_TaskSet(unittest.TestCase):
  72. def test_interface__compat(self):
  73. warnings.resetwarnings()
  74. def with_catch_warnings(log):
  75. ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
  76. self.assertTrue(log)
  77. self.assertIn("Using this invocation of TaskSet is deprecated",
  78. log[0].message.args[0])
  79. self.assertListEqual(ts.tasks,
  80. [MockTask.subtask((i, i))
  81. for i in (2, 4, 8)])
  82. return ts
  83. context = catch_warnings(record=True)
  84. execute_context(context, with_catch_warnings)
  85. # TaskSet.task (deprecated)
  86. def with_catch_warnings2(log):
  87. ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
  88. self.assertEqual(ts.task.name, MockTask.name)
  89. self.assertTrue(log)
  90. self.assertIn("TaskSet.task is deprecated",
  91. log[0].message.args[0])
  92. execute_context(catch_warnings(record=True), with_catch_warnings2)
  93. # TaskSet.task_name (deprecated)
  94. def with_catch_warnings3(log):
  95. ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
  96. self.assertEqual(ts.task_name, MockTask.name)
  97. self.assertTrue(log)
  98. self.assertIn("TaskSet.task_name is deprecated",
  99. log[0].message.args[0])
  100. execute_context(catch_warnings(record=True), with_catch_warnings3)
  101. def test_task_arg_can_be_iterable__compat(self):
  102. ts = TaskSet([MockTask.subtask((i, i))
  103. for i in (2, 4, 8)])
  104. self.assertEqual(len(ts), 3)
  105. def test_respects_ALWAYS_EAGER(self):
  106. app = app_or_default()
  107. class MockTaskSet(TaskSet):
  108. applied = 0
  109. def apply(self, *args, **kwargs):
  110. self.applied += 1
  111. ts = MockTaskSet([MockTask.subtask((i, i))
  112. for i in (2, 4, 8)])
  113. app.conf.CELERY_ALWAYS_EAGER = True
  114. try:
  115. ts.apply_async()
  116. finally:
  117. app.conf.CELERY_ALWAYS_EAGER = False
  118. self.assertEqual(ts.applied, 1)
  119. def test_apply_async(self):
  120. applied = [0]
  121. class mocksubtask(subtask):
  122. def apply_async(self, *args, **kwargs):
  123. applied[0] += 1
  124. ts = TaskSet([mocksubtask(MockTask, (i, i))
  125. for i in (2, 4, 8)])
  126. ts.apply_async()
  127. self.assertEqual(applied[0], 3)
  128. def test_apply(self):
  129. applied = [0]
  130. class mocksubtask(subtask):
  131. def apply(self, *args, **kwargs):
  132. applied[0] += 1
  133. ts = TaskSet([mocksubtask(MockTask, (i, i))
  134. for i in (2, 4, 8)])
  135. ts.apply()
  136. self.assertEqual(applied[0], 3)