test_canvas.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from flaky import flaky
  4. from celery import chain, chord, group
  5. from celery.exceptions import TimeoutError
  6. from celery.result import AsyncResult, GroupResult
  7. from .tasks import add, collect_ids, ids
  8. TIMEOUT = 120
  9. @flaky
  10. class test_chain:
  11. def test_simple_chain(self, manager):
  12. c = add.s(4, 4) | add.s(8) | add.s(16)
  13. assert c().get(timeout=TIMEOUT) == 32
  14. def test_complex_chain(self, manager):
  15. c = (
  16. add.s(2, 2) | (
  17. add.s(4) | add.s(8) | add.s(16)
  18. ) |
  19. group(add.s(i) for i in range(4))
  20. )
  21. res = c()
  22. assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
  23. def test_parent_ids(self, manager, num=10):
  24. assert manager.inspect().ping()
  25. c = chain(ids.si(i=i) for i in range(num))
  26. c.freeze()
  27. res = c()
  28. try:
  29. res.get(timeout=TIMEOUT)
  30. except TimeoutError:
  31. print(manager.inspect.active())
  32. print(manager.inspect.reserved())
  33. print(manager.inspect.stats())
  34. raise
  35. self.assert_ids(res, num - 1)
  36. def assert_ids(self, res, size):
  37. i, root = size, res
  38. while root.parent:
  39. root = root.parent
  40. node = res
  41. while node:
  42. root_id, parent_id, value = node.get(timeout=30)
  43. assert value == i
  44. if node.parent:
  45. assert parent_id == node.parent.id
  46. assert root_id == root.id
  47. node = node.parent
  48. i -= 1
  49. @flaky
  50. class test_group:
  51. def test_parent_ids(self, manager):
  52. assert manager.inspect().ping()
  53. g = (
  54. ids.si(i=1) |
  55. ids.si(i=2) |
  56. group(ids.si(i=i) for i in range(2, 50))
  57. )
  58. res = g()
  59. expected_root_id = res.parent.parent.id
  60. expected_parent_id = res.parent.id
  61. values = res.get(timeout=TIMEOUT)
  62. for i, r in enumerate(values):
  63. root_id, parent_id, value = r
  64. assert root_id == expected_root_id
  65. assert parent_id == expected_parent_id
  66. assert value == i + 2
  67. def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
  68. root_id, parent_id, value = r.get(timeout=TIMEOUT)
  69. assert expected_value == value
  70. assert root_id == expected_root_id
  71. assert parent_id == expected_parent_id
  72. @flaky
  73. class test_chord:
  74. def test_parent_ids(self, manager):
  75. if not manager.app.conf.result_backend.startswith('redis'):
  76. raise pytest.skip('Requires redis result backend.')
  77. root = ids.si(i=1)
  78. expected_root_id = root.freeze().id
  79. g = chain(
  80. root, ids.si(i=2),
  81. chord(
  82. group(ids.si(i=i) for i in range(3, 50)),
  83. chain(collect_ids.s(i=50) | ids.si(i=51)),
  84. ),
  85. )
  86. self.assert_parentids_chord(g(), expected_root_id)
  87. def test_parent_ids__OR(self, manager):
  88. if not manager.app.conf.result_backend.startswith('redis'):
  89. raise pytest.skip('Requires redis result backend.')
  90. root = ids.si(i=1)
  91. expected_root_id = root.freeze().id
  92. g = (
  93. root |
  94. ids.si(i=2) |
  95. group(ids.si(i=i) for i in range(3, 50)) |
  96. collect_ids.s(i=50) |
  97. ids.si(i=51)
  98. )
  99. self.assert_parentids_chord(g(), expected_root_id)
  100. def assert_parentids_chord(self, res, expected_root_id):
  101. assert isinstance(res, AsyncResult)
  102. assert isinstance(res.parent, AsyncResult)
  103. assert isinstance(res.parent.parent, GroupResult)
  104. assert isinstance(res.parent.parent.parent, AsyncResult)
  105. assert isinstance(res.parent.parent.parent.parent, AsyncResult)
  106. # first we check the last task
  107. assert_ids(res, 51, expected_root_id, res.parent.id)
  108. # then the chord callback
  109. prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
  110. assert value == 50
  111. assert root_id == expected_root_id
  112. # started by one of the chord header tasks.
  113. assert parent_id in res.parent.parent.results
  114. # check what the chord callback recorded
  115. for i, p in enumerate(prev):
  116. root_id, parent_id, value = p
  117. assert root_id == expected_root_id
  118. assert parent_id == res.parent.parent.parent.id
  119. # ids(i=2)
  120. root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
  121. assert value == 2
  122. assert parent_id == res.parent.parent.parent.parent.id
  123. assert root_id == expected_root_id
  124. # ids(i=1)
  125. root_id, parent_id, value = res.parent.parent.parent.parent.get(
  126. timeout=30)
  127. assert value == 1
  128. assert root_id == expected_root_id
  129. assert parent_id is None