Parcourir la source

TaskSet.run() now returns a celery.result.TaskSetResult instance, which lets
you inspect the status and return values of a taskset as if it were a single
entity.

Ask Solem il y a 16 ans
Parent
commit
61cbba5345
2 fichiers modifiés avec 159 ajouts et 57 suppressions
  1. 126 0
      celery/result.py
  2. 33 57
      celery/task.py

+ 126 - 0
celery/result.py

@@ -4,6 +4,9 @@ Asynchronous result types.
 
 """
 from celery.backends import default_backend
+from celery.datastructures import PositionQueue
+from celery.timer import TimeoutTimer
+from itertools import imap
 
 
 class BaseAsyncResult(object):
@@ -136,3 +139,126 @@ class AsyncResult(BaseAsyncResult):
 
     def __init__(self, task_id):
         super(AsyncResult, self).__init__(task_id, backend=default_backend)
+
+
class TaskSetResult(object):
    """Working with :class:`celery.task.TaskSet` results.

    An instance of this class is returned by
    :meth:`celery.task.TaskSet.run`. It lets you inspect the status and
    return values of a taskset as a single entity.

    :option taskset_id: see :attr:`taskset_id`.
    :option subtask_ids: see :attr:`subtask_ids`.

    .. attribute:: taskset_id

        The UUID of the taskset itself.

    .. attribute:: subtask_ids

        The list of task UUIDs for all of the subtasks.

    .. attribute:: subtasks

        A list of :class:`AsyncResult` instances for all of the subtasks,
        in the same order as :attr:`subtask_ids`.

    """

    def __init__(self, taskset_id, subtask_ids):
        self.taskset_id = taskset_id
        self.subtask_ids = subtask_ids
        # Build a real list rather than relying on ``map``: the results are
        # iterated many times and counted with ``len()`` in :attr:`total`.
        self.subtasks = [AsyncResult(task_id)
                            for task_id in self.subtask_ids]

    def itersubtasks(self):
        """:returns: an iterator for iterating over the tasksets
        :class:`AsyncResult` objects."""
        return iter(self.subtasks)

    def successful(self):
        """:returns: ``True`` if all of the tasks in the taskset finished
        successfully (i.e. did not raise an exception)."""
        return all(subtask.successful()
                        for subtask in self.itersubtasks())

    def failed(self):
        """:returns: ``True`` if any of the tasks in the taskset failed.
        (i.e., raised an exception)"""
        # A task that is still pending is not successful, but it has not
        # failed either, so test for the failure status explicitly instead
        # of negating ``successful()``.
        return any(subtask.status == "FAILURE"
                        for subtask in self.itersubtasks())

    def waiting(self):
        """:returns: ``True`` if any of the tasks in the taskset is still
        waiting for execution."""
        return any(not subtask.ready()
                        for subtask in self.itersubtasks())

    def ready(self):
        """:returns: ``True`` if all of the tasks in the taskset have been
        executed."""
        return all(subtask.ready()
                        for subtask in self.itersubtasks())

    def completed_count(self):
        """:returns: the number of tasks that completed successfully."""
        return sum(1 for subtask in self.itersubtasks()
                        if subtask.successful())

    def __iter__(self):
        """``iter(res)`` -> ``res.iterate()``."""
        return self.iterate()

    def iterate(self):
        """Iterate over the return values of the tasks as they finish
        one by one.

        :raises: The exception if any of the tasks raised an exception.

        """
        # NOTE: this polls the subtasks in a loop with no pause between
        # passes until every one of them has finished.
        pending = list(self.subtasks)
        while pending:
            unfinished = []
            for subtask in pending:
                # Read the status once per pass so both comparisons see
                # the same value.
                status = subtask.status
                if status == "DONE":
                    yield subtask.result
                elif status == "FAILURE":
                    raise subtask.result
                else:
                    unfinished.append(subtask)
            # Rebuild the pending list instead of deleting from the
            # collection that is being iterated over.
            pending = unfinished

    def join(self, timeout=None):
        """Gather the results for all of the tasks in the taskset,
        and return a list with them ordered by the order of which they
        were called.

        :keyword timeout: The time in seconds, how long
            it will wait for results, before the operation times out.

        :raises celery.timer.TimeoutError: if ``timeout`` is not ``None``
            and the operation takes longer than ``timeout`` seconds.

        If any of the tasks raises an exception, the exception
        will be reraised by :meth:`join`.

        :returns: list of return values for all tasks in the taskset.

        """
        timeout_timer = TimeoutTimer(timeout) # Timeout timer starts here.
        results = PositionQueue(length=self.total)
        # Remember each subtask's position so the return values keep the
        # order the tasks were called in, whatever order they finish in.
        pending = list(enumerate(self.subtasks))

        while True:
            unfinished = []
            for position, subtask in pending:
                status = subtask.status
                if status == "DONE":
                    results[position] = subtask.result
                elif status == "FAILURE":
                    raise subtask.result
                else:
                    unfinished.append((position, subtask))
            # Finished subtasks are not polled again on the next pass.
            pending = unfinished

            if results.full():
                # Make list copy, so the returned type is not a position
                # queue.
                return list(results)

            # This raises TimeoutError when timed out.
            timeout_timer.tick()

    @property
    def total(self):
        """The total number of tasks in the :class:`celery.task.TaskSet`."""
        return len(self.subtasks)

+ 33 - 57
celery/task.py

@@ -4,14 +4,12 @@ Working with tasks and task sets.
 
 """
 from carrot.connection import DjangoAMQPConnection
+from celery.messaging import TaskPublisher, TaskConsumer
 from celery.log import setup_logger
 from celery.registry import tasks
-from celery.messaging import TaskPublisher, TaskConsumer
 from datetime import timedelta
 from celery.backends import default_backend
-from celery.datastructures import PositionQueue
-from celery.result import AsyncResult
-from celery.timer import TimeoutTimer
+from celery.result import AsyncResult, TaskSetResult
 import uuid
 import pickle
 
@@ -54,10 +52,10 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
     for option_name, option_value in message_opts.items():
         message_opts[option_name] = getattr(task, option_name, option_value)
 
-    amqp_connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
-    publisher = TaskPublisher(connection=amqp_connection)
-    task_id = publisher.delay_task(task.name, args, kwargs, **message_opts)
-    amqp_connection.close()
+    with DjangoAMQPConnection(connect_timeout=connect_timeout) as conn:
+        with TaskPublisher(connection=conn) as publisher:
+            task_id = publisher.delay_task(task.name, args, kwargs,
+                                           **message_opts)
     return AsyncResult(task_id)
 
 
@@ -310,7 +308,7 @@ class TaskSet(object):
         ...                 [], {"feed_url": "http://bbc.com/rss"},
         ...                 [], {"feed_url": "http://xkcd.com/rss"}])
 
-        >>> taskset_id, subtask_ids = taskset.run()
+        >>> taskset_result = taskset.run()
         >>> list_of_return_values = taskset_result.join()
 
 
@@ -329,39 +327,43 @@ class TaskSet(object):
     def run(self):
         """Run all tasks in the taskset.
 
-        :returns: A tuple containing the taskset id, and a list
-            of subtask ids.
-
-        :rtype: tuple
+        :returns: A :class:`celery.result.TaskSetResult` instance.
 
         Example
 
-            >>> ts = RefreshFeeds([
+            >>> ts = TaskSet(RefreshFeedTask, [
             ...         ["http://foo.com/rss", {}],
             ...         ["http://bar.com/rss", {}],
             ... ])
-            >>> taskset_id, subtask_ids = ts.run()
-            >>> taskset_id
+            >>> result = ts.run()
+            >>> result.taskset_id
             "d2c9b261-8eff-4bfb-8459-1e1b72063514"
-            >>> subtask_ids
+            >>> result.subtask_ids
             ["b4996460-d959-49c8-aeb9-39c530dcde25",
             "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
+            >>> result.waiting()
+            True
             >>> time.sleep(10)
-            >>> is_done(taskset_id)
+            >>> result.ready()
+            True
+            >>> result.successful()
             True
+            >>> result.failed()
+            False
+            >>> result.join()
+            [True, True]
+
         """
         taskset_id = str(uuid.uuid4())
-        amqp_connection = DjangoAMQPConnection()
-        publisher = TaskPublisher(connection=amqp_connection)
-        subtask_ids = []
-        for arg, kwarg in self.arguments:
-            subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
-                                                     taskset_id=taskset_id,
-                                                     task_args=arg,
-                                                     task_kwargs=kwarg)
-            subtask_ids.append(subtask_id)
-        amqp_connection.close()
-        return taskset_id, subtask_ids
+        with DjangoAMQPConnection() as amqp_connection:
+            with TaskPublisher(connection=amqp_connection) as publisher:
+                subtask_ids = [publisher.delay_task_in_set(
+                                                    task_name=self.task_name,
+                                                    taskset_id=taskset_id,
+                                                    task_args=arg,
+                                                    task_kwargs=kwarg)
+                                for arg, kwarg in self.arguments]
+        return TaskSetResult(taskset_id, subtask_ids)
 
     def iterate(self):
         """Iterate over the results returned after calling :meth:`run`.
@@ -370,16 +372,7 @@ class TaskSet(object):
         be re-raised.
 
         """
-        taskset_id, subtask_ids = self.run()
-        results = dict([(task_id, AsyncResult(task_id))
-                            for task_id in subtask_ids])
-        while results:
-            for task_id, pending_result in results.items():
-                if pending_result.status == "DONE":
-                    del(results[task_id])
-                    yield pending_result.result
-                elif pending_result.status == "FAILURE":
-                    raise pending_result.result
+        return iter(self.run())
 
     def join(self, timeout=None):
         """Gather the results for all of the tasks in the taskset,
@@ -398,24 +391,7 @@ class TaskSet(object):
         :returns: list of return values for all tasks in the taskset.
 
         """
-        timeout_timer = TimeoutTimer(timeout) # Timeout timer starts here.
-        taskset_id, subtask_ids = self.run()
-        pending_results = map(AsyncResult, subtask_ids)
-        results = PositionQueue(length=len(subtask_ids))
-
-        while True:
-            for position, pending_result in enumerate(pending_results):
-                if pending_result.status == "DONE":
-                    results[position] = pending_result.result
-                elif pending_result.status == "FAILURE":
-                    raise pending_result.result
-            if results.full():
-                # Make list copy, so the returned type is not a position
-                # queue.
-                return list(results)
-
-            # This raises TimeoutError when timed out.
-            timeout_timer.tick()
+        return self.run().join(timeout=timeout)
 
     @classmethod
     def remote_execute(cls, func, args):