test_canvas.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from celery import chain, chord, group
  4. from celery.exceptions import TimeoutError
  5. from celery.result import AsyncResult, GroupResult
  6. from .tasks import add, collect_ids, ids
  7. TIMEOUT = 120
  8. class test_chain:
  9. def test_simple_chain(self, manager):
  10. c = add.s(4, 4) | add.s(8) | add.s(16)
  11. assert c().get(timeout=TIMEOUT) == 32
  12. def test_complex_chain(self, manager):
  13. c = (
  14. add.s(2, 2) | (
  15. add.s(4) | add.s(8) | add.s(16)
  16. ) |
  17. group(add.s(i) for i in range(4))
  18. )
  19. res = c()
  20. assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
  21. def test_parent_ids(self, manager, num=10):
  22. assert manager.inspect().ping()
  23. c = chain(ids.si(i=i) for i in range(num))
  24. c.freeze()
  25. res = c()
  26. try:
  27. res.get(timeout=TIMEOUT)
  28. except TimeoutError:
  29. print(manager.inspect.active())
  30. print(manager.inspect.reserved())
  31. print(manager.inspect.stats())
  32. raise
  33. self.assert_ids(res, num - 1)
  34. def assert_ids(self, res, size):
  35. i, root = size, res
  36. while root.parent:
  37. root = root.parent
  38. node = res
  39. while node:
  40. root_id, parent_id, value = node.get(timeout=30)
  41. assert value == i
  42. if node.parent:
  43. assert parent_id == node.parent.id
  44. assert root_id == root.id
  45. node = node.parent
  46. i -= 1
  47. class test_group:
  48. def test_parent_ids(self, manager):
  49. assert manager.inspect().ping()
  50. g = (
  51. ids.si(i=1) |
  52. ids.si(i=2) |
  53. group(ids.si(i=i) for i in range(2, 50))
  54. )
  55. res = g()
  56. expected_root_id = res.parent.parent.id
  57. expected_parent_id = res.parent.id
  58. values = res.get(timeout=TIMEOUT)
  59. for i, r in enumerate(values):
  60. root_id, parent_id, value = r
  61. assert root_id == expected_root_id
  62. assert parent_id == expected_parent_id
  63. assert value == i + 2
  64. def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
  65. root_id, parent_id, value = r.get(timeout=TIMEOUT)
  66. assert expected_value == value
  67. assert root_id == expected_root_id
  68. assert parent_id == expected_parent_id
  69. @pytest.mark.celery(result_backend='redis://')
  70. class test_chord:
  71. def test_parent_ids(self, manager):
  72. root = ids.si(i=1)
  73. expected_root_id = root.freeze().id
  74. g = chain(
  75. root, ids.si(i=2),
  76. chord(
  77. group(ids.si(i=i) for i in range(3, 50)),
  78. chain(collect_ids.s(i=50) | ids.si(i=51)),
  79. ),
  80. )
  81. self.assert_parentids_chord(g(), expected_root_id)
  82. def test_parent_ids__OR(self, manager):
  83. root = ids.si(i=1)
  84. expected_root_id = root.freeze().id
  85. g = (
  86. root |
  87. ids.si(i=2) |
  88. group(ids.si(i=i) for i in range(3, 50)) |
  89. collect_ids.s(i=50) |
  90. ids.si(i=51)
  91. )
  92. self.assert_parentids_chord(g(), expected_root_id)
  93. def assert_parentids_chord(self, res, expected_root_id):
  94. assert isinstance(res, AsyncResult)
  95. assert isinstance(res.parent, AsyncResult)
  96. assert isinstance(res.parent.parent, GroupResult)
  97. assert isinstance(res.parent.parent.parent, AsyncResult)
  98. assert isinstance(res.parent.parent.parent.parent, AsyncResult)
  99. # first we check the last task
  100. assert_ids(res, 51, expected_root_id, res.parent.id)
  101. # then the chord callback
  102. prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
  103. assert value == 50
  104. assert root_id == expected_root_id
  105. # started by one of the chord header tasks.
  106. assert parent_id in res.parent.parent.results
  107. # check what the chord callback recorded
  108. for i, p in enumerate(prev):
  109. root_id, parent_id, value = p
  110. assert root_id == expected_root_id
  111. assert parent_id == res.parent.parent.parent.id
  112. # ids(i=2)
  113. root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
  114. assert value == 2
  115. assert parent_id == res.parent.parent.parent.parent.id
  116. assert root_id == expected_root_id
  117. # ids(i=1)
  118. root_id, parent_id, value = res.parent.parent.parent.parent.get(
  119. timeout=30)
  120. assert value == 1
  121. assert root_id == expected_root_id
  122. assert parent_id is None