浏览代码

Fixed *potential* infinite loop in BaseAsyncResult.__eq__, although no evidence that it has ever been triggered.

Ask Solem 15 年之前
父节点
当前提交
b87d81664d
共有 1 个文件被更改,包括 21 次插入11 次删除
  1. 21 11
      celery/result.py

+ 21 - 11
celery/result.py

@@ -6,6 +6,7 @@ Asynchronous result types.
 """
 import time
 from itertools import imap
+from copy import copy
 
 from celery import states
 from celery.utils import any, all
@@ -98,7 +99,10 @@ class BaseAsyncResult(object):
     def __eq__(self, other):
         if isinstance(other, self.__class__):
             return self.task_id == other.task_id
-        return self == other
+        return other == self.task_id
+
+    def __copy__(self):
+        return self.__class__(self.task_id, backend=self.backend)
 
     @property
     def result(self):
@@ -159,8 +163,9 @@ class AsyncResult(BaseAsyncResult):
 
     """
 
-    def __init__(self, task_id):
-        super(AsyncResult, self).__init__(task_id, backend=default_backend)
+    def __init__(self, task_id, backend=None):
+        backend = backend or default_backend
+        super(AsyncResult, self).__init__(task_id, backend)
 
 
 class TaskSetResult(object):
@@ -262,15 +267,20 @@ class TaskSetResult(object):
         :raises: The exception if any of the tasks raised an exception.
 
         """
-        results = dict((subtask.task_id, subtask.__class__(subtask.task_id))
+        pending = list(self.subtasks)
+        results = dict((subtask.task_id, copy(subtask))
                             for subtask in self.subtasks)
-        while results:
-            for task_id, pending_result in results.items():
-                if pending_result.status == states.SUCCESS:
-                    results.pop(task_id, None)
-                    yield pending_result.result
-                elif pending_result.status == states.FAILURE:
-                    raise pending_result.result
+        while pending:
+            for task_id in pending:
+                result = results[task_id]
+                if result.status == states.SUCCESS:
+                    try:
+                        pending.remove(task_id)
+                    except ValueError:
+                        pass
+                    yield result.result
+                elif result.status == states.FAILURE:
+                    raise result.result
 
     def join(self, timeout=None):
         """Gather the results for all of the tasks in the taskset,