test_canvas.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import pytest
  2. from celery import chain, chord, group
  3. from celery.exceptions import TimeoutError
  4. from celery.result import AsyncResult, GroupResult
  5. from .tasks import add, collect_ids, ids
  6. TIMEOUT = 120
  7. class test_chain:
  8. def test_simple_chain(self, manager):
  9. c = add.s(4, 4) | add.s(8) | add.s(16)
  10. assert c().get(timeout=TIMEOUT) == 32
  11. def test_complex_chain(self, manager):
  12. c = (
  13. add.s(2, 2) | (
  14. add.s(4) | add.s(8) | add.s(16)
  15. ) |
  16. group(add.s(i) for i in range(4))
  17. )
  18. res = c()
  19. assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
  20. def test_parent_ids(self, manager, num=10):
  21. assert manager.inspect().ping()
  22. c = chain(ids.si(i=i) for i in range(num))
  23. c.freeze()
  24. res = c()
  25. try:
  26. res.get(timeout=TIMEOUT)
  27. except TimeoutError:
  28. print(manager.inspect.active())
  29. print(manager.inspect.reserved())
  30. print(manager.inspect.stats())
  31. raise
  32. self.assert_ids(res, num - 1)
  33. def assert_ids(self, res, size):
  34. i, root = size, res
  35. while root.parent:
  36. root = root.parent
  37. node = res
  38. while node:
  39. root_id, parent_id, value = node.get(timeout=30)
  40. assert value == i
  41. if node.parent:
  42. assert parent_id == node.parent.id
  43. assert root_id == root.id
  44. node = node.parent
  45. i -= 1
  46. class test_group:
  47. def test_parent_ids(self, manager):
  48. assert manager.inspect().ping()
  49. g = (
  50. ids.si(i=1) |
  51. ids.si(i=2) |
  52. group(ids.si(i=i) for i in range(2, 50))
  53. )
  54. res = g()
  55. expected_root_id = res.parent.parent.id
  56. expected_parent_id = res.parent.id
  57. values = res.get(timeout=TIMEOUT)
  58. for i, r in enumerate(values):
  59. root_id, parent_id, value = r
  60. assert root_id == expected_root_id
  61. assert parent_id == expected_parent_id
  62. assert value == i + 2
  63. def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
  64. root_id, parent_id, value = r.get(timeout=TIMEOUT)
  65. assert expected_value == value
  66. assert root_id == expected_root_id
  67. assert parent_id == expected_parent_id
  68. @pytest.mark.celery(result_backend='redis://')
  69. class test_chord:
  70. def test_parent_ids(self, manager):
  71. root = ids.si(i=1)
  72. expected_root_id = root.freeze().id
  73. g = chain(
  74. root, ids.si(i=2),
  75. chord(
  76. group(ids.si(i=i) for i in range(3, 50)),
  77. chain(collect_ids.s(i=50) | ids.si(i=51)),
  78. ),
  79. )
  80. self.assert_parentids_chord(g(), expected_root_id)
  81. def test_parent_ids__OR(self, manager):
  82. root = ids.si(i=1)
  83. expected_root_id = root.freeze().id
  84. g = (
  85. root |
  86. ids.si(i=2) |
  87. group(ids.si(i=i) for i in range(3, 50)) |
  88. collect_ids.s(i=50) |
  89. ids.si(i=51)
  90. )
  91. self.assert_parentids_chord(g(), expected_root_id)
  92. def assert_parentids_chord(self, res, expected_root_id):
  93. assert isinstance(res, AsyncResult)
  94. assert isinstance(res.parent, AsyncResult)
  95. assert isinstance(res.parent.parent, GroupResult)
  96. assert isinstance(res.parent.parent.parent, AsyncResult)
  97. assert isinstance(res.parent.parent.parent.parent, AsyncResult)
  98. # first we check the last task
  99. assert_ids(res, 51, expected_root_id, res.parent.id)
  100. # then the chord callback
  101. prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
  102. assert value == 50
  103. assert root_id == expected_root_id
  104. # started by one of the chord header tasks.
  105. assert parent_id in res.parent.parent.results
  106. # check what the chord callback recorded
  107. for i, p in enumerate(prev):
  108. root_id, parent_id, value = p
  109. assert root_id == expected_root_id
  110. assert parent_id == res.parent.parent.parent.id
  111. # ids(i=2)
  112. root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
  113. assert value == 2
  114. assert parent_id == res.parent.parent.parent.parent.id
  115. assert root_id == expected_root_id
  116. # ids(i=1)
  117. root_id, parent_id, value = res.parent.parent.parent.parent.get(
  118. timeout=30)
  119. assert value == 1
  120. assert root_id == expected_root_id
  121. assert parent_id is None