ソースを参照

Tests passing

Ask Solem 13 年 前
コミット
3b1e120f5d

+ 1 - 1
celery/backends/redis.py

@@ -73,7 +73,7 @@ class RedisBackend(KeyValueStoreBackend):
         self.client.delete(key)
 
     def on_chord_apply(self, setid, body, result=None, **kwargs):
-        self.app.TaskSetResult(setid, r).save()
+        self.app.TaskSetResult(setid, result).save()
         pass
 
     def on_chord_part_return(self, task, propagate=False,

+ 3 - 1
celery/tests/test_backends/test_base.py

@@ -7,6 +7,7 @@ import types
 from mock import Mock
 from nose import SkipTest
 
+from celery.result import AsyncResult
 from celery.utils import serialization
 from celery.utils.serialization import subclass_exception
 from celery.utils.serialization import \
@@ -100,7 +101,8 @@ class test_BaseBackend_interface(unittest.TestCase):
         from celery.registry import tasks
         p, tasks[unlock] = tasks.get(unlock), Mock()
         try:
-            b.on_chord_apply("dakj221", "sdokqweok")
+            b.on_chord_apply("dakj221", "sdokqweok",
+                             result=map(AsyncResult, [1, 2, 3]))
             self.assertTrue(tasks[unlock].apply_async.call_count)
         finally:
             tasks[unlock] = p

+ 3 - 1
celery/tests/test_backends/test_redis_unit.py

@@ -6,6 +6,7 @@ from mock import Mock, patch
 
 from celery import current_app
 from celery import states
+from celery.result import AsyncResult
 from celery.registry import tasks
 from celery.task import subtask
 from celery.utils import cached_property, uuid
@@ -93,7 +94,8 @@ class test_RedisBackend(unittest.TestCase):
         self.assertEqual(b.expires, 60)
 
     def test_on_chord_apply(self):
-        self.Backend().on_chord_apply("setid")
+        self.Backend().on_chord_apply("setid", {},
+                                      result=map(AsyncResult, [1, 2, 3]))
 
     def test_mget(self):
         b = self.MockBackend()

+ 43 - 17
celery/tests/test_task/test_chord.py

@@ -3,6 +3,7 @@ from __future__ import absolute_import
 from mock import patch
 
 from celery import current_app
+from celery.result import AsyncResult
 from celery.task import chords
 from celery.task import TaskSet
 from celery.tests.utils import AppCase, Mock
@@ -15,36 +16,61 @@ def add(x, y):
     return x + y
 
 
+@current_app.task
+def callback(r):
+    return r
+
+
+class TSR(chords.TaskSetResult):
+    is_ready = True
+    value = [2, 4, 8, 6]
+
+    def ready(self):
+        return self.is_ready
+
+    def join(self, **kwargs):
+        return self.value
+
+    def join_native(self, **kwargs):
+        return self.value
+
 class test_unlock_chord_task(AppCase):
 
-    @patch("celery.task.chords.TaskSetResult")
     @patch("celery.task.chords._unlock_chord.retry")
-    def test_unlock_ready(self, retry, TaskSetResult):
-        callback = Mock()
-        result = Mock(attrs=dict(ready=lambda: True,
-                                 join=lambda **kw: [2, 4, 8, 6]))
-        TaskSetResult.restore = lambda setid: result
+    def test_unlock_ready(self, retry):
+        callback.apply_async = Mock()
+
+        pts, chords.TaskSetResult = chords.TaskSetResult, TSR
         subtask, chords.subtask = chords.subtask, passthru
         try:
-            chords._unlock_chord("setid", callback)
+            chords._unlock_chord("setid", callback.subtask(),
+                    result=map(AsyncResult, [1, 2, 3]))
         finally:
             chords.subtask = subtask
-        callback.delay.assert_called_with([2, 4, 8, 6])
-        result.delete.assert_called_with()
+            chords.TaskSetResult = pts
+        callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
         # did not retry
         self.assertFalse(retry.call_count)
 
     @patch("celery.task.chords.TaskSetResult")
     @patch("celery.task.chords._unlock_chord.retry")
     def test_when_not_ready(self, retry, TaskSetResult):
-        callback = Mock()
-        result = Mock(attrs=dict(ready=lambda: False))
-        TaskSetResult.restore = lambda setid: result
-        chords._unlock_chord("setid", callback, interval=10, max_retries=30)
-        self.assertFalse(callback.delay.call_count)
-        # did retry
-        chords._unlock_chord.retry.assert_called_with(countdown=10,
-                                                     max_retries=30)
+        callback.apply_async = Mock()
+
+        class NeverReady(TSR):
+            is_ready = False
+
+        pts, chords.TaskSetResult = chords.TaskSetResult, NeverReady
+        try:
+            chords._unlock_chord("setid", callback.subtask, interval=10,
+                                max_retries=30,
+                                result=map(AsyncResult, [1, 2, 3]))
+            self.assertFalse(callback.apply_async.call_count)
+            # did retry
+            chords._unlock_chord.retry.assert_called_with(countdown=10,
+                                                          max_retries=30)
+        finally:
+            chords.TaskSetResult = pts
 
     def test_is_in_registry(self):
         from celery.registry import tasks