test_canvas.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from celery import chain, group, uuid
  4. from celery.exceptions import TimeoutError
  5. from .tasks import add, collect_ids, ids
  6. TIMEOUT = 120
  7. class test_chain:
  8. def test_simple_chain(self, manager):
  9. c = add.s(4, 4) | add.s(8) | add.s(16)
  10. assert c().get(timeout=TIMEOUT) == 32
  11. def test_complex_chain(self, manager):
  12. c = (
  13. add.s(2, 2) | (
  14. add.s(4) | add.s(8) | add.s(16)
  15. ) |
  16. group(add.s(i) for i in range(4))
  17. )
  18. res = c()
  19. assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
  20. def test_parent_ids(self, manager, num=10):
  21. assert manager.inspect().ping()
  22. c = chain(ids.si(i) for i in range(num))
  23. c.freeze()
  24. res = c()
  25. try:
  26. res.get(timeout=TIMEOUT)
  27. except TimeoutError:
  28. print(manager.inspect.active())
  29. print(manager.inspect.reserved())
  30. print(manager.inspect.stats())
  31. raise
  32. self.assert_ids(res, num - 1)
  33. def assert_ids(self, res, size):
  34. i, root = size, res
  35. while root.parent:
  36. root = root.parent
  37. node = res
  38. while node:
  39. root_id, parent_id, value = node.get(timeout=30)
  40. assert value == i
  41. assert root_id == root.id
  42. if node.parent:
  43. assert parent_id == node.parent.id
  44. node = node.parent
  45. i -= 1
  46. class test_group:
  47. def test_parent_ids(self, manager):
  48. assert manager.inspect().ping()
  49. g = ids.si(1) | ids.si(2) | group(ids.si(i) for i in range(2, 50))
  50. res = g()
  51. expected_root_id = res.parent.parent.id
  52. expected_parent_id = res.parent.id
  53. values = res.get(timeout=TIMEOUT)
  54. for i, r in enumerate(values):
  55. root_id, parent_id, value = r
  56. assert root_id == expected_root_id
  57. assert parent_id == expected_parent_id
  58. assert value == i + 2
# Requests a redis result backend through the celery pytest marker
# (presumably because this workflow needs it — TODO confirm).
@pytest.mark.celery(result_backend='redis://')
# NOTE(review): unlike the ``test_*`` classes in this module, the ``xxx_``
# prefix presumably keeps pytest from collecting these tests; confirm before
# renaming.
class xxx_chord:
    """Root/parent id propagation through a chain that contains a chord."""

    def test_parent_ids(self, manager):
        """Run the check with no pre-assigned root or parent id."""
        self.assert_parentids_chord()

    def test_parent_ids__already_set(self, manager):
        """Run the check with freshly generated root and parent ids."""
        self.assert_parentids_chord(uuid(), uuid())

    def assert_parentids_chord(self, base_root=None, base_parent=None):
        """Run ``ids(1) | ids(2) | group(ids 3..49) | collect_ids | ids(51)``
        and check the ids reported by every stage.

        :param base_root: root id given to ``freeze``/``apply_async``; when
            falsy, the expected root is the id of ``res.parent.parent.parent``.
        :param base_parent: parent id given to ``freeze``/``apply_async``;
            the first link must report it as its parent.
        """
        g = (
            ids.si(1) |
            ids.si(2) |
            group(ids.si(i) for i in range(3, 50)) |
            collect_ids.s(i=50) |
            ids.si(51)
        )
        g.freeze(root_id=base_root, parent_id=base_parent)
        res = g.apply_async(root_id=base_root, parent_id=base_parent)
        expected_root_id = base_root or res.parent.parent.parent.id

        # Last link, ids.si(51): returns (root_id, parent_id, value).
        root_id, parent_id, value = res.get(timeout=30)
        assert value == 51
        assert root_id == expected_root_id
        assert parent_id == res.parent.id

        # collect_ids: returns the group's results plus its own id tuple.
        prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
        assert value == 50
        assert root_id == expected_root_id
        assert parent_id == res.parent.parent.id

        # Group members share the same root and parent.
        # NOTE(review): ``i`` is unused and member values are not checked.
        for i, p in enumerate(prev):
            root_id, parent_id, value = p
            assert root_id == expected_root_id
            assert parent_id == res.parent.parent.id

        # The link asserted here to have value 2, i.e. ids.si(2).
        root_id, parent_id, value = res.parent.parent.get(timeout=30)
        assert value == 2
        assert parent_id == res.parent.parent.parent.id
        assert root_id == expected_root_id

        # The link asserted here to have value 1, i.e. ids.si(1); its parent
        # is whatever ``base_parent`` was passed (None by default).
        root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
        assert value == 1
        assert root_id == expected_root_id
        assert parent_id == base_parent