Browse Source

[Redis][async] Fixes waiting for groups, and more

Ask Solem 9 years ago
parent
commit
e6835bdbd2
5 changed files with 56 additions and 35 deletions
  1. 24 21
      celery/backends/async.py
  2. 3 1
      celery/backends/base.py
  3. 5 0
      celery/backends/redis.py
  4. 10 2
      celery/result.py
  5. 14 11
      celery/tests/tasks/test_result.py

+ 24 - 21
celery/backends/async.py

@@ -8,9 +8,9 @@
 from __future__ import absolute_import, unicode_literals
 
 import socket
-import time
 
 from collections import deque
+from time import sleep
 from weakref import WeakKeyDictionary
 
 from kombu.syn import detect_environment
@@ -82,7 +82,7 @@ class greenletDrainer(Drainer):
         if self._g is None:
             self.start()
         if not p.ready:
-            time.sleep(0)
+            sleep(0)
 
 
 @register_drainer('eventlet')
@@ -115,22 +115,22 @@ class AsyncBackendMixin(object):
             raise StopIteration()
 
         bucket = deque()
-        for result in results:
-            if result._cache:
-                bucket.append(result)
+        for node in results:
+            if node._cache:
+                bucket.append(node)
             else:
-                self._collect_into(result, bucket)
+                self._collect_into(node, bucket)
 
         for _ in self._wait_for_pending(
                 result,
                 timeout=timeout, interval=interval, no_ack=no_ack,
                 on_message=on_message, on_interval=on_interval):
             while bucket:
-                result = bucket.popleft()
-                yield result.id, result._cache
+                node = bucket.popleft()
+                yield result.id, node._cache
         while bucket:
-            result = bucket.popleft()
-            yield result.id, result._cache
+            node = bucket.popleft()
+            yield result.id, node._cache
 
     def add_pending_result(self, result):
         if result.id not in self._pending_results:
@@ -152,13 +152,12 @@ class AsyncBackendMixin(object):
             pass
         return result.maybe_throw(callback=callback, propagate=propagate)
 
-    def _wait_for_pending(self, result, timeout=None, interval=0.5,
-                          no_ack=True, on_interval=None, on_message=None,
-                          callback=None, propagate=True):
+    def _wait_for_pending(self, result,
+                          timeout=None, on_interval=None, on_message=None,
+                          **kwargs):
         return self.result_consumer._wait_for_pending(
-            result, timeout=timeout, interval=interval,
-            no_ack=no_ack, on_interval=on_interval,
-            callback=callback, on_message=on_message, propagate=propagate,
+            result, timeout=timeout,
+            on_interval=on_interval, on_message=on_message,
         )
 
     @property
@@ -205,21 +204,25 @@ class BaseResultConsumer(object):
         return self.drainer.drain_events_until(
             p, timeout=timeout, on_interval=on_interval)
 
-    def _wait_for_pending(self, result, timeout=None, interval=0.5,
-                          no_ack=True, on_interval=None, callback=None,
-                          on_message=None, propagate=True):
+    def _wait_for_pending(self, result,
+                          timeout=None, on_interval=None, on_message=None,
+                          **kwargs):
+        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
         prev_on_m, self.on_message = self.on_message, on_message
         try:
             for _ in self.drain_events_until(
                     result.on_ready, timeout=timeout,
                     on_interval=on_interval):
                 yield
-                time.sleep(0)
+                sleep(0)
         except socket.timeout:
             raise TimeoutError('The operation timed out.')
         finally:
             self.on_message = prev_on_m
 
+    def on_wait_for_pending(self, result, timeout=None, **kwargs):
+        pass
+
     def on_out_of_band_result(self, message):
         self.on_state_change(message.payload, message)
 
@@ -238,4 +241,4 @@ class BaseResultConsumer(object):
                 buckets.pop(result)
             except KeyError:
                 pass
-        time.sleep(0)
+        sleep(0)

+ 3 - 1
celery/backends/base.py

@@ -537,7 +537,7 @@ class BaseKeyValueStoreBackend(Backend):
             }
 
     def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
-                 on_message=None, on_interval=None,
+                 on_message=None, on_interval=None, max_iterations=None,
                  READY_STATES=states.READY_STATES):
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
@@ -571,6 +571,8 @@ class BaseKeyValueStoreBackend(Backend):
                 on_interval()
             time.sleep(interval)  # don't busy loop.
             iterations += 1
+            if max_iterations and iterations >= max_iterations:
+                break
 
     def _forget(self, task_id):
         self.delete(self.get_key_for_task(task_id))

+ 5 - 0
celery/backends/redis.py

@@ -65,6 +65,11 @@ class ResultConsumer(async.BaseResultConsumer):
         )
         self._consume_from(initial_task_id)
 
+    def on_wait_for_pending(self, result, **kwargs):
+        for meta in result._iter_meta():
+            if meta is not None:
+                self.on_state_change(meta, None)
+
     def stop(self):
         if self._pubsub is not None:
             self._pubsub.close()

+ 10 - 2
celery/result.py

@@ -15,7 +15,7 @@ from contextlib import contextmanager
 from copy import copy
 
 from kombu.utils import cached_property
-from vine import Thenable, promise
+from vine import Thenable, barrier, promise
 
 from . import current_app
 from . import states
@@ -356,6 +356,9 @@ class AsyncResult(ResultBase):
             return self._maybe_set_cache(self.backend.get_task_meta(self.id))
         return self._cache
 
+    def _iter_meta(self):
+        return iter([self._get_task_meta()])
+
     def _set_cache(self, d):
         children = d.get('children')
         if children:
@@ -438,7 +441,7 @@ class ResultSet(ResultBase):
         self._cache = None
         self.results = results
         self.on_ready = promise(args=(self,))
-        self._on_full = ready_barrier
+        self._on_full = ready_barrier or barrier(results)
         if self._on_full:
             self._on_full.then(promise(self.on_ready))
 
@@ -737,6 +740,11 @@ class ResultSet(ResultBase):
                 acc[order_index[task_id]] = value
         return acc
 
+    def _iter_meta(self):
+        return (meta for _, meta in self.backend.get_many(
+            {r.id for r in self.results}, max_iterations=1,
+        ))
+
     def _failed_join_report(self):
         return (res for res in self.results
                 if res.backend.is_cached(res.id) and

+ 14 - 11
celery/tests/tasks/test_result.py

@@ -320,8 +320,11 @@ class test_ResultSet(AppCase):
             [self.app.AsyncResult(t) for t in ['1', '2', '3']])))
 
     def test_eq_other(self):
-        self.assertFalse(self.app.ResultSet([1, 3, 3]) == 1)
-        self.assertTrue(self.app.ResultSet([1]) == self.app.ResultSet([1]))
+        self.assertFalse(self.app.ResultSet(
+            [self.app.AsyncResult(t) for t in [1, 3, 3]]) == 1)
+        rs1 = self.app.ResultSet([self.app.AsyncResult(1)])
+        rs2 = self.app.ResultSet([self.app.AsyncResult(1)])
+        self.assertTrue(rs1 == rs2)
 
     def test_get(self):
         x = self.app.ResultSet([self.app.AsyncResult(t) for t in [1, 2, 3]])
@@ -336,18 +339,18 @@ class test_ResultSet(AppCase):
         self.assertTrue(x.join_native.called)
 
     def test_eq_ne(self):
-        g1 = self.app.ResultSet(
+        g1 = self.app.ResultSet([
             self.app.AsyncResult('id1'),
             self.app.AsyncResult('id2'),
-        )
-        g2 = self.app.ResultSet(
+        ])
+        g2 = self.app.ResultSet([
             self.app.AsyncResult('id1'),
             self.app.AsyncResult('id2'),
-        )
-        g3 = self.app.ResultSet(
+        ])
+        g3 = self.app.ResultSet([
             self.app.AsyncResult('id3'),
             self.app.AsyncResult('id1'),
-        )
+        ])
         self.assertEqual(g1, g2)
         self.assertNotEqual(g1, g3)
         self.assertNotEqual(g1, object())
@@ -366,10 +369,10 @@ class test_ResultSet(AppCase):
         self.assertTrue(x.join.called)
 
     def test_add(self):
-        x = self.app.ResultSet([1])
-        x.add(2)
+        x = self.app.ResultSet([self.app.AsyncResult(1)])
+        x.add(self.app.AsyncResult(2))
         self.assertEqual(len(x), 2)
-        x.add(2)
+        x.add(self.app.AsyncResult(2))
         self.assertEqual(len(x), 2)
 
     @contextmanager