Browse Source

Split TaskSetResult in two: ResultSet and TaskSetResult. Thanks to padt

Ask Solem 14 years ago
parent
commit
7a775f4638
1 changed files with 141 additions and 90 deletions
  1. 141 90
      celery/result.py

+ 141 - 90
celery/result.py

@@ -40,13 +40,6 @@ class BaseAsyncResult(object):
         self.task_name = task_name
         self.app = app_or_default(app)
 
-    def __reduce__(self):
-        if self.task_name:
-            return (_unpickle_result, (self.task_id, self.task_name))
-        else:
-            return (self.__class__, (self.task_id, self.backend,
-                                     None, self.app))
-
     def forget(self):
         """Forget about (and possibly remove the result of) this task."""
         self.backend.forget(self.task_id)
@@ -62,11 +55,11 @@ class BaseAsyncResult(object):
                                 connect_timeout=connect_timeout)
 
     def wait(self, timeout=None, propagate=True, interval=0.5):
-        """Wait for task, and return the result.
+        """Wait until task is ready, and return its result.
 
         .. warning::
 
-           Waiting for subtasks may lead to deadlocks.
+           Waiting for tasks within a task may lead to deadlocks.
            Please read :ref:`task-synchronous-subtasks`.
 
         :keyword timeout: How long to wait, in seconds, before the
@@ -107,7 +100,7 @@ class BaseAsyncResult(object):
         return self.status == states.SUCCESS
 
     def failed(self):
-        """Returns :const:`True` if the task failed by exception."""
+        """Returns :const:`True` if the task failed."""
         return self.status == states.FAILURE
 
     def __str__(self):
@@ -129,6 +122,13 @@ class BaseAsyncResult(object):
     def __copy__(self):
         return self.__class__(self.task_id, backend=self.backend)
 
+    def __reduce__(self):
+        if self.task_name:
+            return (_unpickle_result, (self.task_id, self.task_name))
+        else:
+            return (self.__class__, (self.task_id, self.backend,
+                                     None, self.app))
+
     @property
     def result(self):
         """When the task has been executed, this contains the return value.
@@ -146,11 +146,6 @@ class BaseAsyncResult(object):
         """Get the traceback of a failed task."""
         return self.backend.get_traceback(self.task_id)
 
-    @property
-    def status(self):
-        """Deprecated alias of :attr:`state`."""
-        return self.state
-
     @property
     def state(self):
         """The tasks current state.
@@ -171,23 +166,28 @@ class BaseAsyncResult(object):
 
             *FAILURE*
 
-                The task raised an exception, or has been retried more times
-                than its limit. The :attr:`result` attribute contains the
-                exception raised.
+                The task raised an exception, or has exceeded the retry limit.
+                The :attr:`result` attribute then contains the
+                exception raised by the task.
 
             *SUCCESS*
 
                 The task executed successfully. The :attr:`result` attribute
-                contains the resulting value.
+                then contains the tasks return value.
 
         """
         return self.backend.get_status(self.task_id)
 
+    @property
+    def status(self):
+        """Deprecated alias of :attr:`state`."""
+        return self.state
+
 
 class AsyncResult(BaseAsyncResult):
     """Pending task result using the default backend.
 
-    :param task_id: The tasks uuid.
+    :param task_id: The task uuid.
 
     """
 
@@ -201,77 +201,97 @@ class AsyncResult(BaseAsyncResult):
                                           task_name=task_name, app=app)
 
 
-class TaskSetResult(object):
-    """Working with :class:`~celery.task.sets.TaskSet` results.
+class ResultSet(object):
+    """Working with more than one result.
 
-    An instance of this class is returned by
-    `TaskSet`'s :meth:`~celery.task.TaskSet.apply_async()`.  It enables
-    inspection of the subtasks state and return values as a single entity.
-
-    :param taskset_id: The id of the taskset.
-    :param subtasks: List of result instances.
+    :param results: List of result instances.
 
     """
 
-    #: The UUID of the taskset.
-    taskset_id = None
-
-    #: A list of :class:`AsyncResult` instances for all of the subtasks.
-    subtasks = None
+    #: List of results in in the set.
+    results = None
 
-    def __init__(self, taskset_id, subtasks, app=None):
-        self.taskset_id = taskset_id
-        self.subtasks = subtasks
+    def __init__(self, results, app=None, **kwargs):
         self.app = app_or_default(app)
+        self.results = results
 
-    def itersubtasks(self):
-        """Taskset subtask iterator.
+    def add(self, result):
+        """Add :class:`AsyncResult` as a new member of the set.
 
-        :returns: an iterator for iterating over the tasksets
-            :class:`AsyncResult` objects.
+        Does nothing if the result is already a member.
 
         """
-        return (subtask for subtask in self.subtasks)
+        if result not in self.results:
+            self.results.append(result)
+
+    def remove(self, result):
+        """Removes result from the set; it must be a member.
+
+        :raises KeyError: if the result is not a member.
+
+        """
+        if isinstance(result, basestring):
+            result = AsyncResult(result)
+        try:
+            self.results.remove(result)
+        except ValueError:
+            raise KeyError(result)
+
+    def discard(self, result):
+        """Remove result from the set if it is a member.
+
+        If it is not a member, do nothing.
+
+        """
+        try:
+            self.remove(result)
+        except KeyError:
+            pass
+
+    def update(self, results):
+        """Update set with the union of itself and an iterable with
+        results."""
+        self.results.extend(r for r in results if r not in self.results)
+
+    def clear(self):
+        """Remove all results from this set."""
+        self.results[:] = []  # don't create new list.
 
     def successful(self):
-        """Was the taskset successful?
+        """Was all of the tasks successful?
 
-        :returns: :const:`True` if all of the tasks in the taskset finished
+        :returns: :const:`True` if all of the tasks finished
             successfully (i.e. did not raise an exception).
 
         """
-        return all(subtask.successful()
-                        for subtask in self.itersubtasks())
+        return all(result.successful() for result in self.results)
 
     def failed(self):
-        """Did the taskset fail?
+        """Did any of the tasks fail?
 
-        :returns: :const:`True` if any of the tasks in the taskset failed.
+        :returns: :const:`True` if any of the tasks failed.
             (i.e., raised an exception)
 
         """
-        return any(subtask.failed()
-                        for subtask in self.itersubtasks())
+        return any(result.failed() for result in self.results)
 
     def waiting(self):
-        """Is the taskset waiting?
+        """Are any of the tasks incomplete?
 
-        :returns: :const:`True` if any of the tasks in the taskset is still
+        :returns: :const:`True` if any of the tasks is still
             waiting for execution.
 
         """
-        return any(not subtask.ready()
-                        for subtask in self.itersubtasks())
+        return any(not result.ready() for result in self.results)
 
     def ready(self):
-        """Is the task ready?
+        """Did all of the tasks complete? (either by success of failure).
 
-        :returns: :const:`True` if all of the tasks in the taskset has been
+        :returns: :const:`True` if all of the tasks been
             executed.
 
         """
-        return all(subtask.ready()
-                        for subtask in self.itersubtasks())
+        return all(result.ready() for result in self.results)
 
     def completed_count(self):
         """Task completion count.
@@ -279,32 +299,29 @@ class TaskSetResult(object):
         :returns: the number of tasks completed.
 
         """
-        return sum(imap(int, (subtask.successful()
-                                for subtask in self.itersubtasks())))
+        return sum(imap(int, (result.successful() for result in self.results)))
 
     def forget(self):
-        """Forget about (and possible remove the result of) all the tasks
-        in this taskset."""
-        for subtask in self.subtasks:
-            subtask.forget()
+        """Forget about (and possible remove the result of) all the tasks."""
+        for result in self.results:
+            result.forget()
 
     def revoke(self, connection=None, connect_timeout=None):
-        """Revoke all subtasks."""
+        """Revoke all tasks in the set."""
 
         def _do_revoke(connection=None, connect_timeout=None):
-            for subtask in self.subtasks:
-                subtask.revoke(connection=connection)
+            for result in self.results:
+                result.revoke(connection=connection)
 
         return self.app.with_default_connection(_do_revoke)(
                 connection=connection, connect_timeout=connect_timeout)
 
     def __iter__(self):
-        """`iter(res)` -> `res.iterate()`."""
         return self.iterate()
 
     def __getitem__(self, index):
-        """`res[i] -> res.subtasks[i]`"""
-        return self.subtasks[index]
+        """`res[i] -> res.results[i]`"""
+        return self.results[index]
 
     def iterate(self):
         """Iterate over the return values of the tasks as they finish
@@ -313,9 +330,9 @@ class TaskSetResult(object):
         :raises: The exception if any of the tasks raised an exception.
 
         """
-        pending = list(self.subtasks)
-        results = dict((subtask.task_id, copy(subtask))
-                            for subtask in self.subtasks)
+        pending = list(self.results)
+        results = dict((result.task_id, copy(result))
+                            for result in self.results)
         while pending:
             for task_id in pending:
                 result = results[task_id]
@@ -329,8 +346,7 @@ class TaskSetResult(object):
                     raise result.result
 
     def join(self, timeout=None, propagate=True, interval=0.5):
-        """Gathers the results of all tasks in the taskset,
-        and returns a list ordered by the order of the set.
+        """Gathers the results of all tasks as a list in order.
 
         .. note::
 
@@ -342,13 +358,13 @@ class TaskSetResult(object):
 
         .. warning::
 
-            Waiting for subtasks may lead to deadlocks.
+            Waiting for tasks within a task may lead to deadlocks.
             Please see :ref:`task-synchronous-subtasks`.
 
         :keyword timeout: The number of seconds to wait for results before
                           the operation times out.
 
-        :keyword propagate: If any of the subtasks raises an exception, the
+        :keyword propagate: If any of the tasks raises an exception, the
                             exception will be re-raised.
 
         :keyword interval: Time to wait (in seconds) before retrying to
@@ -365,20 +381,20 @@ class TaskSetResult(object):
         remaining = None
 
         results = []
-        for subtask in self.subtasks:
+        for result in self.results:
             remaining = None
             if timeout:
                 remaining = timeout - (time.time() - time_start)
                 if remaining <= 0.0:
                     raise TimeoutError("join operation timed out")
-            results.append(subtask.wait(timeout=remaining,
-                                        propagate=propagate,
-                                        interval=interval))
+            results.append(result.wait(timeout=remaining,
+                                       propagate=propagate,
+                                       interval=interval))
         return results
 
     def iter_native(self, timeout=None):
-        backend = self.subtasks[0].backend
-        ids = [subtask.task_id for subtask in self.subtasks]
+        backend = self.results[0].backend
+        ids = [result.task_id for result in self.results]
         return backend.get_many(ids, timeout=timeout)
 
     def join_native(self, timeout=None, propagate=True):
@@ -392,18 +408,54 @@ class TaskSetResult(object):
         This is currently only supported by the AMQP result backend.
 
         """
-        backend = self.subtasks[0].backend
-        results = [None for _ in xrange(len(self.subtasks))]
+        backend = self.results[0].backend
+        results = [None for _ in xrange(len(self.results))]
 
-        ids = [subtask.task_id for subtask in self.subtasks]
+        ids = [result.task_id for result in self.results]
         states = dict(backend.get_many(ids, timeout=timeout))
 
         for task_id, meta in states.items():
-            index = self.subtasks.index(task_id)
+            index = self.results.index(task_id)
             results[index] = meta["result"]
 
         return list(results)
 
+    @property
+    def total(self):
+        """Total number of tasks in the set."""
+        return len(self.results)
+
+    @property
+    def subtasks(self):
+        """Deprecated alias to :attr:`results`."""
+        return self.results
+
+
+class TaskSetResult(ResultSet):
+    """An instance of this class is returned by
+    `TaskSet`'s :meth:`~celery.task.TaskSet.apply_async` method.
+
+    It enables inspection of the tasks state and return values as a single entity.
+
+    :param taskset_id: The id of the taskset.
+    :param results: List of result instances.
+
+    """
+
+    #: The UUID of the taskset.
+    taskset_id = None
+
+    #: List/iterator of results in the taskset
+    results = None
+
+    def __init__(self, taskset_id, results=None, **kwargs):
+        self.taskset_id = taskset_id
+
+        # XXX previously the "results" arg was named "subtasks".
+        if "subtasks" in kwargs:
+            results = kwargs["subtasks"]
+        super(TaskSetResult, self).__init__(results, **kwargs)
+
     def save(self, backend=None):
         """Save taskset result for later retrieval using :meth:`restore`.
 
@@ -424,10 +476,9 @@ class TaskSetResult(object):
             backend = current_app.backend
         return backend.restore_taskset(taskset_id)
 
-    @property
-    def total(self):
-        """Total number of subtasks in the set."""
-        return len(self.subtasks)
+    def itersubtasks(self):
+        """Depreacted.   Use ``iter(self.results)`` instead."""
+        return iter(self.results)
 
 
 class EagerResult(BaseAsyncResult):