Browse Source

TaskSetResult.join + AsyncResult.wait improvements.

The implementation actually haven't changed since v0.4.0, and we can
simplify this a lot now.

An `interval` keyword argument has been added to both so the
polling interval can be specified.

A `propagate` keyword argument has been added to `result.wait()`,
errors will be included instead of raised if this is set to False.

Polling results when using the database backend is still probabably
not a good idea.
Ask Solem 14 years ago
parent
commit
dfbac4a6f9

+ 8 - 4
celery/backends/amqp.py

@@ -131,7 +131,8 @@ class AMQPBackend(BaseDictBackend):
 
         return self.poll(task_id)
 
-    def wait_for(self, task_id, timeout=None, cache=True):
+    def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
+            **kwargs):
         cached_meta = self._cache.get(task_id)
 
         if cache and cached_meta and \
@@ -143,10 +144,13 @@ class AMQPBackend(BaseDictBackend):
             except socket.timeout:
                 raise TimeoutError("The operation timed out.")
 
-        if meta["status"] == states.SUCCESS:
+        state = meta["status"]
+        if state == states.SUCCESS:
+            return meta["result"]
+        elif state in states.PROPAGATE_STATES:
+            if propagate:
+                raise self.exception_to_python(meta["result"])
             return meta["result"]
-        elif meta["status"] in states.PROPAGATE_STATES:
-            raise self.exception_to_python(meta["result"])
         else:
             return self.wait_for(task_id, timeout, cache)
 

+ 7 - 5
celery/backends/base.py

@@ -71,7 +71,7 @@ class BaseBackend(object):
         raise NotImplementedError("%s does not implement forget." % (
                     self.__class__))
 
-    def wait_for(self, task_id, timeout=None):
+    def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
         """Wait for task and return its result.
 
         If the task raises an exception, this exception
@@ -83,7 +83,6 @@ class BaseBackend(object):
 
         """
 
-        sleep_inbetween = 0.5
         time_elapsed = 0.0
 
         while True:
@@ -91,10 +90,13 @@ class BaseBackend(object):
             if status == states.SUCCESS:
                 return self.get_result(task_id)
             elif status in states.PROPAGATE_STATES:
-                raise self.get_result(task_id)
+                result = self.get_result(task_id)
+                if propagate:
+                    raise result
+                return result
             # avoid hammering the CPU checking status.
-            time.sleep(sleep_inbetween)
-            time_elapsed += sleep_inbetween
+            time.sleep(interval)
+            time_elapsed += interval
             if timeout and time_elapsed >= timeout:
                 raise TimeoutError("The operation timed out.")
 

+ 0 - 39
celery/datastructures.py

@@ -140,45 +140,6 @@ class ConfigurationView(AttributeDictMixin):
         return tuple(self.iteritems())
 
 
-class PositionQueue(UserList):
-    """A positional queue of a specific length, with slots that are either
-    filled or unfilled. When all of the positions are filled, the queue
-    is considered :meth:`full`.
-
-    :param length: Number of items to fill.
-
-    """
-
-    #: The number of items required for the queue to be considered full.
-    length = None
-
-    class UnfilledPosition(object):
-        """Describes an unfilled slot."""
-
-        def __init__(self, position):
-            # This is not used, but is an argument from xrange
-            # so why not.
-            self.position = position
-
-    def __init__(self, length):
-        self.length = length
-        self.data = map(self.UnfilledPosition, xrange(length))
-
-    def full(self):
-        """Returns :const:`True` if all of the slots has been filled."""
-        return len(self) >= self.length
-
-    def __len__(self):
-        """`len(self)` -> number of slots filled with real values."""
-        return len(self.filled)
-
-    @property
-    def filled(self):
-        """All filled slots as a list."""
-        return [slot for slot in self.data
-                    if not isinstance(slot, self.UnfilledPosition)]
-
-
 class ExceptionInfo(object):
     """Exception wrapping an exception and its traceback.
 

+ 54 - 9
celery/result.py

@@ -7,7 +7,6 @@ from itertools import imap
 
 from celery import states
 from celery.app import app_or_default
-from celery.datastructures import PositionQueue
 from celery.exceptions import TimeoutError
 from celery.registry import _unpickle_task
 from celery.utils.compat import any, all
@@ -61,11 +60,21 @@ class BaseAsyncResult(object):
         self.app.control.revoke(self.task_id, connection=connection,
                                 connect_timeout=connect_timeout)
 
-    def wait(self, timeout=None):
+    def wait(self, timeout=None, propagate=True, interval=0.5):
         """Wait for task, and return the result.
 
+        .. warning::
+
+           Waiting for subtasks may lead to deadlocks.
+           Please read :ref:`task-synchronous-subtasks`.
+
         :keyword timeout: How long to wait, in seconds, before the
                           operation times out.
+        :keyword propagate: Re-raise exception if the task failed.
+        :keyword interval: Time to wait (in seconds) before retrying to
+           retrieve the result.  Note that this does not have any effect
+           when using the AMQP result store backend, as it does not
+           use polling.
 
         :raises celery.exceptions.TimeoutError: if `timeout` is not
             :const:`None` and the result does not arrive within `timeout`
@@ -75,7 +84,9 @@ class BaseAsyncResult(object):
         be re-raised.
 
         """
-        return self.backend.wait_for(self.task_id, timeout=timeout)
+        return self.backend.wait_for(self.task_id, timeout=timeout,
+                                                   propagate=propagate,
+                                                   interval=interval)
 
     def get(self, timeout=None):
         """Alias to :meth:`wait`."""
@@ -319,24 +330,56 @@ class TaskSetResult(object):
                 elif result.status in states.PROPAGATE_STATES:
                     raise result.result
 
-    def join(self, timeout=None, propagate=True):
+    def join(self, timeout=None, propagate=True, interval=0.5):
         """Gather the results of all tasks in the taskset,
         and returns a list ordered by the order of the set.
 
+        .. note::
+
+            This can be an very expensive operation on result store
+            backends that must resort to polling (e.g. database).
+
+            You should consider using :meth:`join_native` if your backends
+            supports it.
+
+        .. warning::
+
+            Waiting for subtasks may lead the deadlocks.
+            Please see :ref:`task-synchronous-subtasks`.
+
         :keyword timeout: The number of seconds to wait for results before
                           the operation times out.
 
         :keyword propagate: If any of the subtasks raises an exception, the
                             exception will be reraised.
 
+        :keyword interval: Time to wait (in seconds) before retrying to
+                           retrieve a result from the set.  Note that this
+                           does not have any effect when using the AMQP
+                           result store backend, as it does not use polling.
+
         :raises celery.exceptions.TimeoutError: if `timeout` is not
             :const:`None` and the operation takes longer than `timeout`
             seconds.
 
         """
-
         time_start = time.time()
-        results = PositionQueue(length=self.total)
+        remaining = None
+
+        results = []
+        for subtask in self.subtasks:
+            remaining = None
+            if timeout:
+                remaining = timeout - (time.time() - time_start)
+                if remaining <= 0.0:
+                    raise TimeoutError("join operation timed out")
+            results.append(subtask.wait(timeout=remaining,
+                                        propagate=propagate,
+                                        interval=interval))
+        return results
+
+
+
 
         while True:
             for position, pending_result in enumerate(self.subtasks):
@@ -371,7 +414,7 @@ class TaskSetResult(object):
 
         """
         backend = self.subtasks[0].backend
-        results = PositionQueue(length=self.total)
+        results = [None for _ in xrange(len(self.subtasks))]
 
         ids = [subtask.task_id for subtask in self.subtasks]
         states = dict(backend.get_many(ids, timeout=timeout))
@@ -426,12 +469,14 @@ class EagerResult(BaseAsyncResult):
         """Returns :const:`True` if the task has been executed."""
         return True
 
-    def wait(self, timeout=None):
+    def wait(self, timeout=None, propagate=True, **kwargs):
         """Wait until the task has been executed and return its result."""
         if self.state == states.SUCCESS:
             return self.result
         elif self.state in states.PROPAGATE_STATES:
-            raise self.result
+            if propagate:
+                raise self.result
+            return self.result
 
     def revoke(self):
         self._state = states.REVOKED

+ 1 - 31
celery/tests/test_datastructures.py

@@ -2,7 +2,7 @@ import sys
 from celery.tests.utils import unittest
 from Queue import Queue
 
-from celery.datastructures import PositionQueue, ExceptionInfo, LocalCache
+from celery.datastructures import ExceptionInfo, LocalCache
 from celery.datastructures import LimitedSet, SharedCounter, consume_queue
 from celery.datastructures import AttributeDict, DictAttribute
 from celery.datastructures import ConfigurationView
@@ -71,36 +71,6 @@ class test_ConfigurationView(unittest.TestCase):
         self.assertDictEqual(dict(self.view.items()), expected)
 
 
-class test_PositionQueue(unittest.TestCase):
-
-    def test_position_queue_unfilled(self):
-        q = PositionQueue(length=10)
-        for position in q.data:
-            self.assertIsInstance(position, q.UnfilledPosition)
-
-        self.assertListEqual(q.filled, [])
-        self.assertEqual(len(q), 0)
-        self.assertFalse(q.full())
-
-    def test_position_queue_almost(self):
-        q = PositionQueue(length=10)
-        q[3] = 3
-        q[6] = 6
-        q[9] = 9
-
-        self.assertListEqual(q.filled, [3, 6, 9])
-        self.assertEqual(len(q), 3)
-        self.assertFalse(q.full())
-
-    def test_position_queue_full(self):
-        q = PositionQueue(length=10)
-        for i in xrange(10):
-            q[i] = i
-        self.assertListEqual(q.filled, list(xrange(10)))
-        self.assertEqual(len(q), 10)
-        self.assertTrue(q.full())
-
-
 class test_ExceptionInfo(unittest.TestCase):
 
     def test_exception_info(self):