Browse Source

Add regression test for chain duplication in chords (#3771)

georgepsarakis 8 years ago
parent
commit
38943e6000
2 changed files with 30 additions and 1 deletions
  1. 9 0
      t/integration/tasks.py
  2. 21 1
      t/integration/test_canvas.py

+ 9 - 0
t/integration/tasks.py

@@ -63,3 +63,12 @@ def retry_once(self):
     if self.request.retries:
     if self.request.retries:
         return self.request.retries
         return self.request.retries
     raise self.retry(countdown=0.1)
     raise self.retry(countdown=0.1)
+
+
+@shared_task
+def redis_echo(message):
+    """Task that appends the message to a redis list"""
+    from redis import StrictRedis
+
+    redis_connection = StrictRedis()
+    redis_connection.rpush('redis-echo', message)

+ 21 - 1
t/integration/test_canvas.py

@@ -1,10 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 import pytest
 import pytest
+from redis import StrictRedis
 from celery import chain, chord, group
 from celery import chain, chord, group
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult
 from celery.result import AsyncResult, GroupResult
 from .conftest import flaky
 from .conftest import flaky
-from .tasks import add, add_replaced, add_to_all, collect_ids, ids
+from .tasks import add, add_replaced, add_to_all, collect_ids, ids, redis_echo
 
 
 TIMEOUT = 120
 TIMEOUT = 120
 
 
@@ -27,6 +28,25 @@ class test_chain:
         res = c()
         res = c()
         assert res.get(timeout=TIMEOUT) == [64, 65, 66, 67]
         assert res.get(timeout=TIMEOUT) == [64, 65, 66, 67]
 
 
+    @flaky
+    def test_group_chord_group_chain(self, manager):
+        if not manager.app.conf.result_backend.startswith('redis'):
+            raise pytest.skip('Requires redis result backend.')
+        redis_connection = StrictRedis()
+        redis_connection.delete('redis-echo')
+        before = group(redis_echo.si('before {}'.format(i)) for i in range(3))
+        connect = redis_echo.si('connect')
+        after = group(redis_echo.si('after {}'.format(i)) for i in range(2))
+
+        result = (before | connect | after).delay()
+        result.get(timeout=TIMEOUT)
+        redis_messages = redis_connection.lrange('redis-echo', 0, -1)
+        assert set(['before 0', 'before 1', 'before 2']) == \
+            set(redis_messages[:3])
+        assert redis_messages[3] == 'connect'
+        assert set(redis_messages[4:]) == set(['after 0', 'after 1'])
+        redis_connection.delete('redis-echo')
+
     @flaky
     @flaky
     def test_parent_ids(self, manager, num=10):
     def test_parent_ids(self, manager, num=10):
         assert manager.inspect().ping()
         assert manager.inspect().ping()