Kaynağa Gözat

Proper implementation of TaskSet.iterate

Ask Solem 14 yıl önce
ebeveyn
işleme
b8a293d40f
2 değiştirilmiş dosya ile 23 ekleme ve 13 silme
  1. 15 13
      celery/result.py
  2. 8 0
      celery/tests/test_task/test_result.py

+ 15 - 13
celery/result.py

@@ -323,27 +323,29 @@ class ResultSet(object):
         """`res[i] -> res.results[i]`"""
         """`res[i] -> res.results[i]`"""
         return self.results[index]
         return self.results[index]
 
 
-    def iterate(self):
+    def iterate(self, timeout=None, propagate=True, interval=0.5):
         """Iterate over the return values of the tasks as they finish
         """Iterate over the return values of the tasks as they finish
         one by one.
         one by one.
 
 
         :raises: The exception if any of the tasks raised an exception.
         :raises: The exception if any of the tasks raised an exception.
 
 
         """
         """
-        pending = list(self.results)
+        elapsed = 0.0
         results = dict((result.task_id, copy(result))
         results = dict((result.task_id, copy(result))
                             for result in self.results)
                             for result in self.results)
-        while pending:
-            for task_id in pending:
-                result = results[task_id]
-                if result.status == states.SUCCESS:
-                    try:
-                        pending.remove(task_id)
-                    except ValueError:
-                        pass
-                    yield result.result
-                elif result.status in states.PROPAGATE_STATES:
-                    raise result.result
+
+        while results:
+            removed = set()
+            for task_id, result in results.iteritems():
+                yield result.get(timeout=timeout and timeout - elapsed,
+                                 propagate=propagate, interval=0.0)
+                removed.add(task_id)
+            for task_id in removed:
+                results.pop(task_id, None)
+            time.sleep(interval)
+            elapsed += interval
+            if timeout and elapsed >= timeout:
+                raise TimeoutError("The operation timed out")
 
 
     def join(self, timeout=None, propagate=True, interval=0.5):
     def join(self, timeout=None, propagate=True, interval=0.5):
         """Gathers the results of all tasks as a list in order.
         """Gathers the results of all tasks as a list in order.

+ 8 - 0
celery/tests/test_task/test_result.py

@@ -148,6 +148,11 @@ class MockAsyncResultFailure(AsyncResult):
     def status(self):
     def status(self):
         return states.FAILURE
         return states.FAILURE
 
 
+    def get(self, propagate=True, **kwargs):
+        if propagate:
+            raise self.result
+        return self.result
+
 
 
 class MockAsyncResultSuccess(AsyncResult):
 class MockAsyncResultSuccess(AsyncResult):
     forgotten = False
     forgotten = False
@@ -163,6 +168,9 @@ class MockAsyncResultSuccess(AsyncResult):
     def status(self):
     def status(self):
         return states.SUCCESS
         return states.SUCCESS
 
 
+    def get(self, **kwargs):
+        return self.result
+
 
 
 class SimpleBackend(object):
 class SimpleBackend(object):
         ids = []
         ids = []