浏览代码

Chord tests passing

Ask Solem 13 年之前
父节点
当前提交
0cabb808b1
共有 4 个文件被更改,包括 31 次插入28 次删除
  1. 2 2
      celery/app/builtins.py
  2. 1 1
      celery/result.py
  3. 5 5
      celery/task/sets.py
  4. 23 20
      celery/tests/test_task/test_chord.py

+ 2 - 2
celery/app/builtins.py

@@ -49,13 +49,13 @@ def add_unlock_chord_task(app):
     It creates a task chain polling the header for completion.
     It creates a task chain polling the header for completion.
 
 
     """
     """
-    from celery.result import AsyncResult, TaskSetResult
     from celery.task.sets import subtask
     from celery.task.sets import subtask
+    from celery import result as _res
 
 
     @app.task(name="celery.chord_unlock", max_retries=None)
     @app.task(name="celery.chord_unlock", max_retries=None)
     def unlock_chord(setid, callback, interval=1, propagate=False,
     def unlock_chord(setid, callback, interval=1, propagate=False,
             max_retries=None, result=None):
             max_retries=None, result=None):
-        result = TaskSetResult(setid, map(AsyncResult, result))
+        result = _res.TaskSetResult(setid, map(_res.AsyncResult, result))
         j = result.join_native if result.supports_native_join else result.join
         j = result.join_native if result.supports_native_join else result.join
         if result.ready():
         if result.ready():
             subtask(callback).delay(j(propagate=propagate))
             subtask(callback).delay(j(propagate=propagate))

+ 1 - 1
celery/result.py

@@ -570,7 +570,7 @@ class TaskSetResult(ResultSet):
         # XXX previously the "results" arg was named "subtasks".
         # XXX previously the "results" arg was named "subtasks".
         if "subtasks" in kwargs:
         if "subtasks" in kwargs:
             results = kwargs["subtasks"]
             results = kwargs["subtasks"]
-        super(TaskSetResult, self).__init__(results, **kwargs)
+        ResultSet.__init__(self, results, **kwargs)
 
 
     def save(self, backend=None):
     def save(self, backend=None):
         """Save taskset result for later retrieval using :meth:`restore`.
         """Save taskset result for later retrieval using :meth:`restore`.

+ 5 - 5
celery/task/sets.py

@@ -49,10 +49,10 @@ class subtask(AttributeDict):
 
 
     def __init__(self, task=None, args=None, kwargs=None, options=None,
     def __init__(self, task=None, args=None, kwargs=None, options=None,
                 type=None, **ex):
                 type=None, **ex):
-        init = super(subtask, self).__init__
+        init = AttributeDict.__init__
 
 
         if isinstance(task, dict):
         if isinstance(task, dict):
-            return init(task)  # works like dict(d)
+            return init(self, task)  # works like dict(d)
 
 
         # Also supports using task class/instance instead of string name.
         # Also supports using task class/instance instead of string name.
         try:
         try:
@@ -64,9 +64,9 @@ class subtask(AttributeDict):
             # will add it to dict(self)
             # will add it to dict(self)
             object.__setattr__(self, "_type", task)
             object.__setattr__(self, "_type", task)
 
 
-        init(task=task_name, args=tuple(args or ()),
-                             kwargs=dict(kwargs or {}, **ex),
-                             options=options or {})
+        init(self, task=task_name, args=tuple(args or ()),
+                                   kwargs=dict(kwargs or {}, **ex),
+                                   options=options or {})
 
 
     def delay(self, *argmerge, **kwmerge):
     def delay(self, *argmerge, **kwmerge):
         """Shortcut to `apply_async(argmerge, kwargs)`."""
         """Shortcut to `apply_async(argmerge, kwargs)`."""

+ 23 - 20
celery/tests/test_task/test_chord.py

@@ -9,6 +9,7 @@ from celery import result
 from celery.result import AsyncResult, TaskSetResult
 from celery.result import AsyncResult, TaskSetResult
 from celery.task import chords
 from celery.task import chords
 from celery.task import task, TaskSet
 from celery.task import task, TaskSet
+from celery.task import sets
 from celery.tests.utils import AppCase, Mock
 from celery.tests.utils import AppCase, Mock
 
 
 passthru = lambda x: x
 passthru = lambda x: x
@@ -26,7 +27,7 @@ def callback(r):
 
 
 class TSR(TaskSetResult):
 class TSR(TaskSetResult):
     is_ready = True
     is_ready = True
-    value = [2, 4, 8, 6]
+    value = None
 
 
     def ready(self):
     def ready(self):
         return self.is_ready
         return self.is_ready
@@ -51,31 +52,26 @@ class test_unlock_chord_task(AppCase):
 
 
     @patch("celery.result.TaskSetResult")
     @patch("celery.result.TaskSetResult")
     def test_unlock_ready(self, TaskSetResult):
     def test_unlock_ready(self, TaskSetResult):
-        from nose import SkipTest
-        raise SkipTest("Not passing")
 
 
-        class NeverReady(TSR):
-            is_ready = False
+        class AlwaysReady(TSR):
+            is_ready = True
+            value = [2, 4, 8, 6]
 
 
         @task
         @task
         def callback(*args, **kwargs):
         def callback(*args, **kwargs):
             pass
             pass
 
 
-        pts, result.TaskSetResult = result.TaskSetResult, NeverReady
+        pts, result.TaskSetResult = result.TaskSetResult, AlwaysReady
         callback.apply_async = Mock()
         callback.apply_async = Mock()
         try:
         try:
             with patch_unlock_retry() as (unlock, retry):
             with patch_unlock_retry() as (unlock, retry):
-                res = Mock(attrs=dict(ready=lambda: True,
-                                        join=lambda **kw: [2, 4, 8, 6]))
-                TaskSetResult.restore = lambda setid: res
-                subtask, chords.subtask = chords.subtask, passthru
+                subtask, sets.subtask = sets.subtask, passthru
                 try:
                 try:
                     unlock("setid", callback,
                     unlock("setid", callback,
                            result=map(AsyncResult, [1, 2, 3]))
                            result=map(AsyncResult, [1, 2, 3]))
                 finally:
                 finally:
                     chords.subtask = subtask
                     chords.subtask = subtask
                 callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
                 callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
-                result.delete.assert_called_with()
                 # did not retry
                 # did not retry
                 self.assertFalse(retry.call_count)
                 self.assertFalse(retry.call_count)
         finally:
         finally:
@@ -83,16 +79,21 @@ class test_unlock_chord_task(AppCase):
 
 
     @patch("celery.result.TaskSetResult")
     @patch("celery.result.TaskSetResult")
     def test_when_not_ready(self, TaskSetResult):
     def test_when_not_ready(self, TaskSetResult):
-        from nose import SkipTest
-        raise SkipTest("Not passing")
         with patch_unlock_retry() as (unlock, retry):
         with patch_unlock_retry() as (unlock, retry):
-            callback = Mock()
-            result = Mock(attrs=dict(ready=lambda: False))
-            TaskSetResult.restore = lambda setid: result
-            unlock("setid", callback, interval=10, max_retries=30,)
-            self.assertFalse(callback.delay.call_count)
-            # did retry
-            unlock.retry.assert_called_with(countdown=10, max_retries=30)
+
+            class NeverReady(TSR):
+                is_ready = False
+
+            pts, result.TaskSetResult = result.TaskSetResult, NeverReady
+            try:
+                callback = Mock()
+                unlock("setid", callback, interval=10, max_retries=30,
+                            result=map(AsyncResult, [1, 2, 3]))
+                self.assertFalse(callback.delay.call_count)
+                # did retry
+                unlock.retry.assert_called_with(countdown=10, max_retries=30)
+            finally:
+                result.TaskSetResult = pts
 
 
     def test_is_in_registry(self):
     def test_is_in_registry(self):
         self.assertIn("celery.chord_unlock", current_app.tasks)
         self.assertIn("celery.chord_unlock", current_app.tasks)
@@ -116,6 +117,8 @@ class test_Chord_task(AppCase):
 
 
     def test_run(self):
     def test_run(self):
         prev, current_app.backend = current_app.backend, Mock()
         prev, current_app.backend = current_app.backend, Mock()
+        current_app.backend.cleanup = Mock()
+        current_app.backend.cleanup.__name__ = "cleanup"
         try:
         try:
             Chord = current_app.tasks["celery.chord"]
             Chord = current_app.tasks["celery.chord"]