Browse Source

Added TaskSetResult.iter_native: Native way to iterate on multiple results as they finish

Ask Solem 14 years ago
parent
commit
a083980ab6
2 changed files with 10 additions and 7 deletions
  1. 3 5
      celery/backends/amqp.py
  2. 7 2
      celery/result.py

+ 3 - 5
celery/backends/amqp.py

@@ -207,7 +207,6 @@ class AMQPBackend(BaseDictBackend):
             consumer = self._create_consumer(bindings, channel)
             consumer.consume()
             ids = set(task_ids)
-            results = {}
             cached_ids = set()
             for task_id in ids:
                 try:
@@ -216,22 +215,21 @@ class AMQPBackend(BaseDictBackend):
                     pass
                 else:
                     if cached["status"] in states.READY_STATES:
-                        results[task_id] = cached
+                        yield task_id, cached
                         cached_ids.add(task_id)
             ids ^= cached_ids
             try:
                 while ids:
                     r = self.drain_events(consumer, timeout=timeout)
-                    results.update(r)
                     ids ^= set(r.keys())
+                    for ready_id, ready_meta in r.items():
+                        yield ready_id, ready_meta
             finally:
                 consumer.cancel()
         finally:
             channel.close()
             conn.release()
 
-    return results
-
     def close(self):
         if self._pool is not None:
             self._pool.close()

+ 7 - 2
celery/result.py

@@ -355,6 +355,11 @@ class TaskSetResult(object):
                         time.time() >= time_start + timeout):
                     raise TimeoutError("join operation timed out.")
 
+    def iter_native(self, timeout=None):
+        backend = self.subtasks[0].backend
+        ids = [subtask.task_id for subtask in self.subtasks]
+        return backend.get_many(ids, timeout=timeout)
+
     def join_native(self, timeout=None, propagate=True):
         """Backend optimized version of :meth:`join`.
 
@@ -368,9 +373,9 @@ class TaskSetResult(object):
         """
         backend = self.subtasks[0].backend
         results = PositionQueue(length=self.total)
-        ids = [subtask.task_id for subtask in self.subtasks]
 
-        states = backend.get_many(ids, timeout=timeout)
+        ids = [subtask.task_id for subtask in self.subtasks]
+        states = dict(backend.get_many(ids, timeout=timeout))
 
         for task_id, meta in states.items():
             index = self.subtasks.index(task_id)