test_task_sets.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. from __future__ import absolute_import
  2. from __future__ import with_statement
  3. import anyjson
  4. from celery import registry
  5. from celery.app import app_or_default
  6. from celery.exceptions import CDeprecationWarning
  7. from celery.task import Task
  8. from celery.task.sets import subtask, TaskSet
  9. from celery.tests.utils import Case
  class MockTask(Task):
      """Fake ``tasks.add`` task used by the subtask/TaskSet tests.

      ``apply_async`` and ``apply`` do not run anything. They hand back the
      ``(args, kwargs, options)`` they received, so a test can inspect exactly
      what a subtask passed through after merging its arguments.
      """

      name = "tasks.add"

      def run(self, x, y, **kwargs):
          return x + y

      @classmethod
      def apply_async(cls, args, kwargs, **options):
          # Echo the call back to the caller instead of dispatching it.
          echoed = (args, kwargs, options)
          return echoed

      @classmethod
      def apply(cls, args, kwargs, **options):
          # Mirror of apply_async so ``subtask.apply`` can be inspected too.
          echoed = (args, kwargs, options)
          return echoed
  class test_subtask(Case):
      """Tests for the ``subtask`` call-signature wrapper."""

      def _cpu_bound(self, args):
          """Return a MockTask subtask with ``cache=True`` and a CPU-bound route.

          New dicts are built on every call, so one test's mutations cannot
          leak into another.
          """
          return MockTask.subtask(args, {"cache": True},
                                  {"routing_key": "CPU-bound"})

      def test_behaves_like_type(self):
          original = subtask("tasks.add", (2, 2), {"cache": True},
                             {"routing_key": "CPU-bound"})
          # Wrapping an existing subtask must give an equal mapping.
          self.assertDictEqual(subtask(original), original)

      def test_task_argument_can_be_task_cls(self):
          sig = subtask(MockTask, (2, 2))
          self.assertEqual(sig.task, MockTask.name)

      def test_apply_async(self):
          sig = self._cpu_bound((2, 2))
          sent_args, sent_kwargs, sent_opts = sig.apply_async()
          self.assertTupleEqual(sent_args, (2, 2))
          self.assertDictEqual(sent_kwargs, {"cache": True})
          self.assertDictEqual(sent_opts, {"routing_key": "CPU-bound"})

      def test_delay_argmerge(self):
          sig = self._cpu_bound((2, ))
          # Call-time positional args come first; call-time kwargs win.
          sent_args, sent_kwargs, sent_opts = sig.delay(10, cache=False,
                                                        other="foo")
          self.assertTupleEqual(sent_args, (10, 2))
          self.assertDictEqual(sent_kwargs, {"cache": False, "other": "foo"})
          self.assertDictEqual(sent_opts, {"routing_key": "CPU-bound"})

      def test_apply_async_argmerge(self):
          sig = self._cpu_bound((2, ))
          sent_args, sent_kwargs, sent_opts = sig.apply_async(
              (10, ), {"cache": False, "other": "foo"},
              routing_key="IO-bound", exchange="fast")
          self.assertTupleEqual(sent_args, (10, 2))
          self.assertDictEqual(sent_kwargs, {"cache": False, "other": "foo"})
          self.assertDictEqual(sent_opts, {"routing_key": "IO-bound",
                                           "exchange": "fast"})

      def test_apply_argmerge(self):
          sig = self._cpu_bound((2, ))
          sent_args, sent_kwargs, sent_opts = sig.apply(
              (10, ), {"cache": False, "other": "foo"},
              routing_key="IO-bound", exchange="fast")
          self.assertTupleEqual(sent_args, (10, 2))
          self.assertDictEqual(sent_kwargs, {"cache": False, "other": "foo"})
          self.assertDictEqual(sent_opts, {"routing_key": "IO-bound",
                                           "exchange": "fast"})

      def test_is_JSON_serializable(self):
          sig = self._cpu_bound((2, ))
          # JSON has no tuple type, so args come back as a list; normalize
          # them up front since the difference is irrelevant here.
          sig.args = list(sig.args)
          wire = anyjson.serialize(sig)
          restored = subtask(anyjson.deserialize(wire))
          self.assertEqual(sig, restored)

      def test_repr(self):
          sig = MockTask.subtask((2, ), {"cache": True})
          text = repr(sig)
          self.assertIn("2", text)
          self.assertIn("cache=True", text)

      def test_reduce(self):
          sig = MockTask.subtask((2, ), {"cache": True})
          factory, ctor_args, _state = sig.__reduce__()
          # Rebuilding from the pickle recipe must give the same mapping.
          self.assertDictEqual(dict(factory(*ctor_args)), dict(sig))
  class test_TaskSet(Case):
      """Tests for ``TaskSet``: compat constructor, eager mode and dispatch."""

      def test_interface__compat(self):
          # The old ``TaskSet(task, [[args], ...])`` form still builds
          # subtasks, but it and the old accessors must warn.
          with self.assertWarnsRegex(CDeprecationWarning,
                  r'Using this invocation of TaskSet is deprecated'):
              ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
          self.assertListEqual(ts.tasks,
                               [MockTask.subtask((i, i))
                                   for i in (2, 4, 8)])
          with self.assertWarnsRegex(CDeprecationWarning,
                  r'TaskSet.task is deprecated'):
              self.assertEqual(ts.task, registry.tasks[MockTask.name])
          with self.assertWarnsRegex(CDeprecationWarning,
                  r'TaskSet.task_name is deprecated'):
              self.assertEqual(ts.task_name, MockTask.name)

      def test_task_arg_can_be_iterable__compat(self):
          ts = TaskSet([MockTask.subtask((i, i))
                          for i in (2, 4, 8)])
          self.assertEqual(len(ts), 3)

      def test_respects_ALWAYS_EAGER(self):
          app = app_or_default()

          class MockTaskSet(TaskSet):
              applied = 0

              def apply(self, *args, **kwargs):
                  self.applied += 1

          ts = MockTaskSet([MockTask.subtask((i, i))
                              for i in (2, 4, 8)])
          # Restore the previously configured value instead of forcing
          # False, so this test cannot leak config state into other tests.
          prev_eager = app.conf.CELERY_ALWAYS_EAGER
          app.conf.CELERY_ALWAYS_EAGER = True
          try:
              ts.apply_async()
          finally:
              app.conf.CELERY_ALWAYS_EAGER = prev_eager
          # With ALWAYS_EAGER set, apply_async must go through apply() once.
          self.assertEqual(ts.applied, 1)

      def test_apply_async(self):
          applied = [0]  # list cell so the nested class can mutate it

          class mocksubtask(subtask):

              def apply_async(self, *args, **kwargs):
                  applied[0] += 1

          ts = TaskSet([mocksubtask(MockTask, (i, i))
                          for i in (2, 4, 8)])
          ts.apply_async()
          self.assertEqual(applied[0], 3)

          class Publisher(object):

              def send(self, *args, **kwargs):
                  pass

          # Passing an explicit publisher must still dispatch every subtask;
          # previously this second call was made but never checked.
          ts.apply_async(publisher=Publisher())
          self.assertEqual(applied[0], 6)

      def test_apply(self):
          applied = [0]  # list cell so the nested class can mutate it

          class mocksubtask(subtask):

              def apply(self, *args, **kwargs):
                  applied[0] += 1

          ts = TaskSet([mocksubtask(MockTask, (i, i))
                          for i in (2, 4, 8)])
          ts.apply()
          self.assertEqual(applied[0], 3)