test_canvas.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from celery import chain, group, uuid
  4. from cyanide.tasks import add, collect_ids, ids
  5. class test_chain:
  6. def test_simple_chain(self, manager):
  7. c = add.s(4, 4) | add.s(8) | add.s(16)
  8. assert manager.join(c()) == 32
  9. def test_complex_chain(self, manager):
  10. c = (
  11. add.s(2, 2) | (
  12. add.s(4) | add.s(8) | add.s(16)
  13. ) |
  14. group(add.s(i) for i in range(4))
  15. )
  16. res = c()
  17. assert res.get() == [32, 33, 34, 35]
  18. def test_parent_ids(self, manager, num=10):
  19. c = chain(ids.si(i) for i in range(num))
  20. c.freeze()
  21. res = c()
  22. res.get(timeout=5)
  23. self.assert_ids(res, num - 1)
  24. def assert_ids(self, res, size):
  25. i, root = size, res
  26. while root.parent:
  27. root = root.parent
  28. node = res
  29. while node:
  30. root_id, parent_id, value = node.get(timeout=5)
  31. assert value == i
  32. assert root_id == root.id
  33. if node.parent:
  34. assert parent_id == node.parent.id
  35. node = node.parent
  36. i -= 1
  37. class test_group:
  38. def test_parent_ids(self):
  39. g = ids.si(1) | ids.si(2) | group(ids.si(i) for i in range(2, 50))
  40. res = g()
  41. expected_root_id = res.parent.parent.id
  42. expected_parent_id = res.parent.id
  43. values = res.get(timeout=5)
  44. for i, r in enumerate(values):
  45. root_id, parent_id, value = r
  46. assert root_id == expected_root_id
  47. assert parent_id == expected_parent_id
  48. assert value == i + 2
  49. class xxx_chord:
  50. @pytest.mark.celery(redis_results=1)
  51. def test_parent_ids(self, manager):
  52. self.assert_parentids_chord()
  53. self.assert_parentids_chord(uuid(), uuid())
  54. def assert_parentids_chord(self, base_root=None, base_parent=None):
  55. g = (
  56. ids.si(1) |
  57. ids.si(2) |
  58. group(ids.si(i) for i in range(3, 50)) |
  59. collect_ids.s(i=50) |
  60. ids.si(51)
  61. )
  62. g.freeze(root_id=base_root, parent_id=base_parent)
  63. res = g.apply_async(root_id=base_root, parent_id=base_parent)
  64. expected_root_id = base_root or res.parent.parent.parent.id
  65. root_id, parent_id, value = res.get(timeout=5)
  66. assert value == 51
  67. assert root_id == expected_root_id
  68. assert parent_id == res.parent.id
  69. prev, (root_id, parent_id, value) = res.parent.get(timeout=5)
  70. assert value == 50
  71. assert root_id == expected_root_id
  72. assert parent_id == res.parent.parent.id
  73. for i, p in enumerate(prev):
  74. root_id, parent_id, value = p
  75. assert root_id == expected_root_id
  76. assert parent_id == res.parent.parent.id
  77. root_id, parent_id, value = res.parent.parent.get(timeout=5)
  78. assert value == 2
  79. assert parent_id == res.parent.parent.parent.id
  80. assert root_id == expected_root_id
  81. root_id, parent_id, value = res.parent.parent.parent.get(timeout=5)
  82. assert value == 1
  83. assert root_id == expected_root_id
  84. assert parent_id == base_parent