Browse Source

Adds AsyncResult.collect to recursively collect results

Ask Solem 13 years ago
parent
commit
a26737db52
1 changed files with 52 additions and 1 deletions
  1. 52 1
      celery/result.py

+ 52 - 1
celery/result.py

@@ -94,6 +94,53 @@ class AsyncResult(object):
                                                    interval=interval)
     wait = get  # deprecated alias to :meth:`get`.
 
+    def collect(self, timeout=None, propagate=True):
+        """Iterator, like :meth:`get` will wait for the task to complete,
+        but will also follow :class:`AsyncResult` and :class:`ResultSet`
+        returned by the task, yielding for each result in the tree.
+
+        An example would be having the following tasks:
+
+        .. code-block:: python
+
+            @task
+            def A(how_many):
+                return TaskSet(B.subtask((i, )) for i in xrange(how_many))
+
+            @task
+            def B(i):
+                return pow2.delay(i)
+
+            @task
+            def pow2(i):
+                return i ** 2
+
+        Calling :meth:`collect` would return:
+
+        .. code-block:: python
+
+            >>> result = A.delay(10)
+            >>> list(result.collect())
+            [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
+
+        """
+        stack = deque([self])
+
+        native_join = self.supports_native_join
+        popleft = stack.popleft
+        extend = stack.extend
+        append = stack.append
+
+        while stack:
+            res = popleft()
+            if isinstance(res, ResultSet):
+                j = res.join_native if native_join else res.join
+                extend(j(timeout=timeout, propagate=propagate))
+            elif isinstance(res, AsyncResult):
+                append(res.get(timeout=timeout, propagate=propagate))
+            else:
+                yield res
+
     def ready(self):
         """Returns :const:`True` if the task has been executed.
 
@@ -137,6 +184,10 @@ class AsyncResult(object):
             return (self.__class__, (self.task_id, self.backend,
                                      None, self.app))
 
+    @property
+    def supports_native_join(self):
+        return self.backend.supports_native_join
+
     @property
     def result(self):
         """When the task has been executed, this contains the return value.
@@ -428,7 +479,7 @@ class ResultSet(object):
 
     @property
     def supports_native_join(self):
-        return self.results[0].backend.supports_native_join
+        return self.results[0].supports_native_join
 
 
 class TaskSetResult(ResultSet):