test_canvas.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from celery import chain, chord, group
  4. from celery.exceptions import TimeoutError
  5. from celery.result import AsyncResult, GroupResult
  6. from .tasks import add, collect_ids, ids
  7. TIMEOUT = 120
  8. class test_chain:
  9. def test_simple_chain(self, manager):
  10. c = add.s(4, 4) | add.s(8) | add.s(16)
  11. assert c().get(timeout=TIMEOUT) == 32
  12. def test_complex_chain(self, manager):
  13. c = (
  14. add.s(2, 2) | (
  15. add.s(4) | add.s(8) | add.s(16)
  16. ) |
  17. group(add.s(i) for i in range(4))
  18. )
  19. res = c()
  20. assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
  21. def test_parent_ids(self, manager, num=10):
  22. assert manager.inspect().ping()
  23. c = chain(ids.si(i=i) for i in range(num))
  24. c.freeze()
  25. res = c()
  26. try:
  27. res.get(timeout=TIMEOUT)
  28. except TimeoutError:
  29. print(manager.inspect.active())
  30. print(manager.inspect.reserved())
  31. print(manager.inspect.stats())
  32. raise
  33. self.assert_ids(res, num - 1)
  34. def assert_ids(self, res, size):
  35. i, root = size, res
  36. while root.parent:
  37. root = root.parent
  38. node = res
  39. while node:
  40. root_id, parent_id, value = node.get(timeout=30)
  41. assert value == i
  42. if node.parent:
  43. assert parent_id == node.parent.id
  44. assert root_id == root.id
  45. node = node.parent
  46. i -= 1
  47. class test_group:
  48. def test_parent_ids(self, manager):
  49. assert manager.inspect().ping()
  50. g = (
  51. ids.si(i=1) |
  52. ids.si(i=2) |
  53. group(ids.si(i=i) for i in range(2, 50))
  54. )
  55. res = g()
  56. expected_root_id = res.parent.parent.id
  57. expected_parent_id = res.parent.id
  58. values = res.get(timeout=TIMEOUT)
  59. for i, r in enumerate(values):
  60. root_id, parent_id, value = r
  61. assert root_id == expected_root_id
  62. assert parent_id == expected_parent_id
  63. assert value == i + 2
  64. def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
  65. root_id, parent_id, value = r.get(timeout=TIMEOUT)
  66. assert expected_value == value
  67. assert root_id == expected_root_id
  68. assert parent_id == expected_parent_id
  69. class test_chord:
  70. def test_parent_ids(self, manager):
  71. if not manager.app.conf.result_backend.startswith('redis'):
  72. raise pytest.skip('Requires redis result backend.')
  73. root = ids.si(i=1)
  74. expected_root_id = root.freeze().id
  75. g = chain(
  76. root, ids.si(i=2),
  77. chord(
  78. group(ids.si(i=i) for i in range(3, 50)),
  79. chain(collect_ids.s(i=50) | ids.si(i=51)),
  80. ),
  81. )
  82. self.assert_parentids_chord(g(), expected_root_id)
  83. def test_parent_ids__OR(self, manager):
  84. if not manager.app.conf.result_backend.startswith('redis'):
  85. raise pytest.skip('Requires redis result backend.')
  86. root = ids.si(i=1)
  87. expected_root_id = root.freeze().id
  88. g = (
  89. root |
  90. ids.si(i=2) |
  91. group(ids.si(i=i) for i in range(3, 50)) |
  92. collect_ids.s(i=50) |
  93. ids.si(i=51)
  94. )
  95. self.assert_parentids_chord(g(), expected_root_id)
  96. def assert_parentids_chord(self, res, expected_root_id):
  97. assert isinstance(res, AsyncResult)
  98. assert isinstance(res.parent, AsyncResult)
  99. assert isinstance(res.parent.parent, GroupResult)
  100. assert isinstance(res.parent.parent.parent, AsyncResult)
  101. assert isinstance(res.parent.parent.parent.parent, AsyncResult)
  102. # first we check the last task
  103. assert_ids(res, 51, expected_root_id, res.parent.id)
  104. # then the chord callback
  105. prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
  106. assert value == 50
  107. assert root_id == expected_root_id
  108. # started by one of the chord header tasks.
  109. assert parent_id in res.parent.parent.results
  110. # check what the chord callback recorded
  111. for i, p in enumerate(prev):
  112. root_id, parent_id, value = p
  113. assert root_id == expected_root_id
  114. assert parent_id == res.parent.parent.parent.id
  115. # ids(i=2)
  116. root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
  117. assert value == 2
  118. assert parent_id == res.parent.parent.parent.parent.id
  119. assert root_id == expected_root_id
  120. # ids(i=1)
  121. root_id, parent_id, value = res.parent.parent.parent.parent.get(
  122. timeout=30)
  123. assert value == 1
  124. assert root_id == expected_root_id
  125. assert parent_id is None