test_canvas.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from celery import chain, chord, group
  4. from celery.exceptions import TimeoutError
  5. from celery.result import AsyncResult, GroupResult
  6. from .conftest import flaky
  7. from .tasks import add, collect_ids, ids
  8. TIMEOUT = 120
  9. class test_chain:
  10. @flaky
  11. def test_simple_chain(self, manager):
  12. c = add.s(4, 4) | add.s(8) | add.s(16)
  13. assert c().get(timeout=TIMEOUT) == 32
  14. @flaky
  15. def test_complex_chain(self, manager):
  16. c = (
  17. add.s(2, 2) | (
  18. add.s(4) | add.s(8) | add.s(16)
  19. ) |
  20. group(add.s(i) for i in range(4))
  21. )
  22. res = c()
  23. assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
  24. @flaky
  25. def test_parent_ids(self, manager, num=10):
  26. assert manager.inspect().ping()
  27. c = chain(ids.si(i=i) for i in range(num))
  28. c.freeze()
  29. res = c()
  30. try:
  31. res.get(timeout=TIMEOUT)
  32. except TimeoutError:
  33. print(manager.inspect.active())
  34. print(manager.inspect.reserved())
  35. print(manager.inspect.stats())
  36. raise
  37. self.assert_ids(res, num - 1)
  38. def assert_ids(self, res, size):
  39. i, root = size, res
  40. while root.parent:
  41. root = root.parent
  42. node = res
  43. while node:
  44. root_id, parent_id, value = node.get(timeout=30)
  45. assert value == i
  46. if node.parent:
  47. assert parent_id == node.parent.id
  48. assert root_id == root.id
  49. node = node.parent
  50. i -= 1
  51. class test_group:
  52. @flaky
  53. def test_parent_ids(self, manager):
  54. assert manager.inspect().ping()
  55. g = (
  56. ids.si(i=1) |
  57. ids.si(i=2) |
  58. group(ids.si(i=i) for i in range(2, 50))
  59. )
  60. res = g()
  61. expected_root_id = res.parent.parent.id
  62. expected_parent_id = res.parent.id
  63. values = res.get(timeout=TIMEOUT)
  64. for i, r in enumerate(values):
  65. root_id, parent_id, value = r
  66. assert root_id == expected_root_id
  67. assert parent_id == expected_parent_id
  68. assert value == i + 2
  69. def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
  70. root_id, parent_id, value = r.get(timeout=TIMEOUT)
  71. assert expected_value == value
  72. assert root_id == expected_root_id
  73. assert parent_id == expected_parent_id
  74. class test_chord:
  75. @flaky
  76. def test_parent_ids(self, manager):
  77. if not manager.app.conf.result_backend.startswith('redis'):
  78. raise pytest.skip('Requires redis result backend.')
  79. root = ids.si(i=1)
  80. expected_root_id = root.freeze().id
  81. g = chain(
  82. root, ids.si(i=2),
  83. chord(
  84. group(ids.si(i=i) for i in range(3, 50)),
  85. chain(collect_ids.s(i=50) | ids.si(i=51)),
  86. ),
  87. )
  88. self.assert_parentids_chord(g(), expected_root_id)
  89. @flaky
  90. def test_parent_ids__OR(self, manager):
  91. if not manager.app.conf.result_backend.startswith('redis'):
  92. raise pytest.skip('Requires redis result backend.')
  93. root = ids.si(i=1)
  94. expected_root_id = root.freeze().id
  95. g = (
  96. root |
  97. ids.si(i=2) |
  98. group(ids.si(i=i) for i in range(3, 50)) |
  99. collect_ids.s(i=50) |
  100. ids.si(i=51)
  101. )
  102. self.assert_parentids_chord(g(), expected_root_id)
  103. def assert_parentids_chord(self, res, expected_root_id):
  104. assert isinstance(res, AsyncResult)
  105. assert isinstance(res.parent, AsyncResult)
  106. assert isinstance(res.parent.parent, GroupResult)
  107. assert isinstance(res.parent.parent.parent, AsyncResult)
  108. assert isinstance(res.parent.parent.parent.parent, AsyncResult)
  109. # first we check the last task
  110. assert_ids(res, 51, expected_root_id, res.parent.id)
  111. # then the chord callback
  112. prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
  113. assert value == 50
  114. assert root_id == expected_root_id
  115. # started by one of the chord header tasks.
  116. assert parent_id in res.parent.parent.results
  117. # check what the chord callback recorded
  118. for i, p in enumerate(prev):
  119. root_id, parent_id, value = p
  120. assert root_id == expected_root_id
  121. assert parent_id == res.parent.parent.parent.id
  122. # ids(i=2)
  123. root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
  124. assert value == 2
  125. assert parent_id == res.parent.parent.parent.parent.id
  126. assert root_id == expected_root_id
  127. # ids(i=1)
  128. root_id, parent_id, value = res.parent.parent.parent.parent.get(
  129. timeout=30)
  130. assert value == 1
  131. assert root_id == expected_root_id
  132. assert parent_id is None