Browse Source

AsyncResult.get now supports the on_interval argument

Ask Solem 9 years ago
parent
commit
2055cbd056
3 changed files with 21 additions and 13 deletions
  1. 4 1
      celery/backends/amqp.py
  2. 3 1
      celery/backends/base.py
  3. 14 11
      celery/result.py

+ 4 - 1
celery/backends/amqp.py

@@ -231,7 +231,8 @@ class AMQPBackend(BaseBackend):
     def _many_bindings(self, ids):
         return [self._create_binding(task_id) for task_id in ids]
 
-    def get_many(self, task_ids, timeout=None, no_ack=True, on_message=None,
+    def get_many(self, task_ids, timeout=None, no_ack=True,
+                 on_message=None, on_interval=None,
                  now=monotonic, getfields=itemgetter('status', 'task_id'),
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
@@ -276,6 +277,8 @@ class AMQPBackend(BaseBackend):
                         ids.discard(task_id)
                         push_cache(task_id, state)
                         yield task_id, state
+                    if on_interval:
+                        on_interval()
 
     def reload_task_result(self, task_id):
         raise NotImplementedError(

+ 3 - 1
celery/backends/base.py

@@ -475,7 +475,7 @@ class KeyValueStoreBackend(BaseBackend):
             }
 
     def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
-                 on_message=None,
+                 on_message=None, on_interval=None,
                  READY_STATES=states.READY_STATES):
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
@@ -505,6 +505,8 @@ class KeyValueStoreBackend(BaseBackend):
                 yield bytes_to_str(key), value
             if timeout and iterations * interval >= timeout:
                 raise TimeoutError('Operation timed out ({0})'.format(timeout))
+            if on_interval:
+                on_interval()
             time.sleep(interval)  # don't busy loop.
             iterations += 1
 

+ 14 - 11
celery/result.py

@@ -14,6 +14,7 @@ from collections import OrderedDict, deque
 from contextlib import contextmanager
 from copy import copy
 
+from amqp import promise
 from kombu.utils import cached_property
 
 from . import current_app
@@ -118,7 +119,7 @@ class AsyncResult(ResultBase):
                                 reply=wait, timeout=timeout)
 
     def get(self, timeout=None, propagate=True, interval=0.5,
-            no_ack=True, follow_parents=True, callback=None,
+            no_ack=True, follow_parents=True, callback=None, on_interval=None,
             EXCEPTION_STATES=states.EXCEPTION_STATES,
             PROPAGATE_STATES=states.PROPAGATE_STATES):
         """Wait until task is ready, and return its result.
@@ -149,10 +150,12 @@ class AsyncResult(ResultBase):
 
         """
         assert_will_not_block()
-        on_interval = None
+        _on_interval = promise()
         if follow_parents and propagate and self.parent:
-            on_interval = self._maybe_reraise_parent_error
-            on_interval()
+            on_interval = promise(self._maybe_reraise_parent_error)
+            self._maybe_reraise_parent_error()
+        if on_interval:
+            _on_interval.then(on_interval)
 
         if self._cache:
             if propagate:
@@ -162,7 +165,7 @@ class AsyncResult(ResultBase):
         meta = self.backend.wait_for(
             self.id, timeout=timeout,
             interval=interval,
-            on_interval=on_interval,
+            on_interval=_on_interval,
             no_ack=no_ack,
         )
         if meta:
@@ -579,7 +582,7 @@ class ResultSet(ResultBase):
         )
 
     def join(self, timeout=None, propagate=True, interval=0.5,
-             callback=None, no_ack=True, on_message=None):
+             callback=None, no_ack=True, on_message=None, on_interval=None):
         """Gathers the results of all tasks as a list in order.
 
         .. note::
@@ -644,7 +647,7 @@ class ResultSet(ResultBase):
                     raise TimeoutError('join operation timed out')
             value = result.get(
                 timeout=remaining, propagate=propagate,
-                interval=interval, no_ack=no_ack,
+                interval=interval, no_ack=no_ack, on_interval=on_interval,
             )
             if callback:
                 callback(result.id, value)
@@ -653,7 +656,7 @@ class ResultSet(ResultBase):
         return results
 
     def iter_native(self, timeout=None, interval=0.5, no_ack=True,
-                    on_message=None):
+                    on_message=None, on_interval=None):
         """Backend optimized version of :meth:`iterate`.
 
         .. versionadded:: 2.2
@@ -671,12 +674,12 @@ class ResultSet(ResultBase):
         return self.backend.get_many(
             {r.id for r in results},
             timeout=timeout, interval=interval, no_ack=no_ack,
-            on_message=on_message,
+            on_message=on_message, on_interval=on_interval,
         )
 
     def join_native(self, timeout=None, propagate=True,
                     interval=0.5, callback=None, no_ack=True,
-                    on_message=None):
+                    on_message=None, on_interval=None):
         """Backend optimized version of :meth:`join`.
 
         .. versionadded:: 2.2
@@ -694,7 +697,7 @@ class ResultSet(ResultBase):
         }
         acc = None if callback else [None for _ in range(len(self))]
         for task_id, meta in self.iter_native(timeout, interval, no_ack,
-                                              on_message):
+                                              on_message, on_interval):
             value = meta['result']
             if propagate and meta['status'] in states.PROPAGATE_STATES:
                 raise value