Преглед на файлове

[async result] Callback based result backends (related to Issue #2529)

Ask Solem преди 9 години
родител
ревизия
9e31b2790c
променени са 4 файла, в които са добавени 248 реда и са изтрити 79 реда
  1. 136 53
      celery/backends/amqp.py
  2. 14 0
      celery/backends/base.py
  3. 68 26
      celery/result.py
  4. 30 0
      funtests/stress/t.py

+ 136 - 53
celery/backends/amqp.py

@@ -46,12 +46,122 @@ class NoCacheQueue(Queue):
     can_cache_declaration = False
 
 
class ResultConsumer(object):
    """Consume result messages for pending tasks on behalf of a backend.

    Keeps one long-lived consumer (started lazily by :meth:`consume_from`)
    whose queue set grows as results are awaited, and dispatches
    ready-state messages back to the registered
    :class:`~celery.result.AsyncResult` instances in *pending_results*.
    """
    Consumer = Consumer

    def __init__(self, backend, app, accept, pending_results):
        self.backend = backend
        self.app = app
        self.accept = accept
        # Maps task_id -> AsyncResult; shared with (and owned by) the backend.
        self._pending_results = pending_results
        self._consumer = None
        self._conn = None
        # Optional hook called with the decoded meta of every message seen.
        self.on_message = None
        # Optional deque collecting ready results (see collect_for_pending).
        self.bucket = None

    def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
        """Block until the result for *task_id* arrives; return its meta."""
        wait = self.drain_events
        with self.app.pool.acquire_channel(block=True) as (conn, channel):
            binding = self.backend._create_binding(task_id)
            with self.Consumer(channel, binding,
                               no_ack=no_ack, accept=self.accept) as consumer:
                while 1:
                    try:
                        return wait(
                            conn, consumer, timeout, on_interval)[task_id]
                    except KeyError:
                        # Got a result for a different task; keep waiting.
                        continue

    def drain_events(self, connection, consumer,
                     timeout=None, on_interval=None, now=monotonic, wait=None):
        """Drain events on *connection* until at least one ready result.

        Returns a mapping of task_id -> decoded meta for every ready
        result seen, merging it into the backend result cache.
        Raises :exc:`socket.timeout` if *timeout* seconds elapse first.

        FIX(review): this method is referenced by :meth:`consume` but was
        missing from the class; restored from the previous backend
        implementation it replaced.
        """
        wait = wait or connection.drain_events
        results = {}

        def on_result(meta, message):
            if meta['status'] in states.READY_STATES:
                results[meta['task_id']] = self.backend.meta_from_decoded(meta)

        consumer.callbacks[:] = [on_result]
        time_start = now()

        while 1:
            # Total time spent may exceed a single call to wait().
            if timeout and now() - time_start >= timeout:
                raise socket.timeout()
            try:
                wait(timeout=1)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if results:  # got event on the wanted channel.
                break
        self.backend._cache.update(results)
        return results

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        """Wait until *result* is ready, then return (or raise) its value."""
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result, timeout=None, interval=0.5,
                          no_ack=True, on_interval=None, callback=None,
                          on_message=None, propagate=True):
        # ``interval``/``no_ack``/``callback``/``propagate`` are accepted
        # for interface compatibility with other backends but unused here.
        prev_on_m, self.on_message = self.on_message, on_message
        try:
            for _ in self.drain_events_until(
                    result.on_ready, timeout=timeout,
                    on_interval=on_interval):
                yield
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = prev_on_m

    def collect_for_pending(self, result, bucket=None, **kwargs):
        """Like :meth:`_wait_for_pending`, appending ready results to *bucket*."""
        prev_bucket, self.bucket = self.bucket, bucket
        try:
            for _ in self._wait_for_pending(result, **kwargs):
                yield
        finally:
            self.bucket = prev_bucket

    def start(self, initial_queue, no_ack=True):
        """Open the dedicated connection and start consuming *initial_queue*."""
        self._conn = self.app.connection()
        self._consumer = self.Consumer(
            self._conn.default_channel, [initial_queue],
            callbacks=[self.on_state_change], no_ack=no_ack,
            accept=self.accept)
        self._consumer.consume()

    def stop(self):
        """Cancel the consumer and close its connection (no-op if unstarted)."""
        if self._consumer is None:
            return
        try:
            self._consumer.cancel()
        finally:
            # BUG FIX: was ``self._connection``, an attribute that is never
            # assigned; start() stores the connection as ``self._conn``.
            self._conn.close()

    def consume_from(self, queue):
        """Start (or extend) the long-lived consumer to include *queue*."""
        if self._consumer is None:
            return self.start(queue)
        if not self._consumer.consuming_from(queue):
            self._consumer.add_queue(queue)
            self._consumer.consume()

    def cancel_for(self, queue):
        # Stop consuming from *queue* (result already delivered).
        self._consumer.cancel_by_queue(queue)

    def on_state_change(self, meta, message):
        """Message callback: fulfill the matching pending result if ready."""
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            try:
                result = self._pending_results[meta['task_id']]
            except KeyError:
                # Not one of ours (or already removed); ignore.
                return
            result._maybe_set_cache(meta)
            if self.bucket is not None:
                self.bucket.append(result)

    def drain_events_until(self, p, timeout=None, on_interval=None,
                           monotonic=monotonic, wait=None):
        """Yield while draining events until promise *p* is fulfilled.

        Raises :exc:`socket.timeout` if *timeout* seconds elapse first.
        """
        wait = wait or self._conn.drain_events
        time_start = monotonic()

        while 1:
            # Total time spent may exceed a single call to wait().
            if timeout and monotonic() - time_start >= timeout:
                raise socket.timeout()
            try:
                yield wait(timeout=1)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break
+
+
 class AMQPBackend(BaseBackend):
     """Publishes results by sending messages."""
     Exchange = Exchange
     Queue = NoCacheQueue
     Consumer = Consumer
     Producer = Producer
+    ResultConsumer = ResultConsumer
 
     BacklogLimitExceeded = BacklogLimitExceeded
 
@@ -83,6 +193,8 @@ class AMQPBackend(BaseBackend):
         self.queue_arguments = dictfilter({
             'x-expires': maybe_s_to_ms(self.expires),
         })
+        self.result_consumer = self.ResultConsumer(
+            self, self.app, self.accept, self._pending_results)
 
     def _create_exchange(self, name, type='direct', delivery_mode=2):
         return self.Exchange(name=name,
@@ -136,22 +248,6 @@ class AMQPBackend(BaseBackend):
    def on_reply_declare(self, task_id):
        # Bindings (queues) that must be declared before publishing the
        # reply for *task_id*.
        return [self._create_binding(task_id)]
 
-    def wait_for(self, task_id, timeout=None, cache=True,
-                 no_ack=True, on_interval=None,
-                 READY_STATES=states.READY_STATES,
-                 PROPAGATE_STATES=states.PROPAGATE_STATES,
-                 **kwargs):
-        cached_meta = self._cache.get(task_id)
-        if cache and cached_meta and \
-                cached_meta['status'] in READY_STATES:
-            return cached_meta
-        else:
-            try:
-                return self.consume(task_id, timeout=timeout, no_ack=no_ack,
-                                    on_interval=on_interval)
-            except socket.timeout:
-                raise TimeoutError('The operation timed out.')
-
     def get_task_meta(self, task_id, backlog_limit=1000):
         # Polling and using basic_get
         with self.app.pool.acquire_channel(block=True) as (_, channel):
@@ -189,50 +285,37 @@ class AMQPBackend(BaseBackend):
                     return {'status': states.PENDING, 'result': None}
     poll = get_task_meta  # XXX compat
 
-    def drain_events(self, connection, consumer,
-                     timeout=None, on_interval=None, now=monotonic, wait=None):
-        wait = wait or connection.drain_events
-        results = {}
+    def wait_for_pending(self, result, timeout=None, interval=0.5,
+                 no_ack=True, on_interval=None, on_message=None,
+                 callback=None, propagate=True):
+        return self.result_consumer.wait_for_pending(
+            result, timeout=timeout, interval=interval,
+            no_ack=no_ack, on_interval=on_interval,
+            callback=callback, on_message=on_message, propagate=propagate,
+        )
 
-        def callback(meta, message):
-            if meta['status'] in states.READY_STATES:
-                results[meta['task_id']] = self.meta_from_decoded(meta)
+    def collect_for_pending(self, result, bucket=None, timeout=None,
+                            interval=0.5, no_ack=True, on_interval=None,
+                            on_message=None, callback=None, propagate=True):
+        return self.result_consumer.collect_for_pending(
+            result, bucket=bucket, timeout=timeout, interval=interval,
+            no_ack=no_ack, on_interval=on_interval,
+            callback=callback, on_message=on_message, propagate=propagate,
+        )
 
-        consumer.callbacks[:] = [callback]
-        time_start = now()
+    def add_pending_result(self, result):
+        if result.id not in self._pending_results:
+            self._pending_results[result.id] = result
+            self.result_consumer.consume_from(self._create_binding(result.id))
 
-        while 1:
-            # Total time spent may exceed a single call to wait()
-            if timeout and now() - time_start >= timeout:
-                raise socket.timeout()
-            try:
-                wait(timeout=1)
-            except socket.timeout:
-                pass
-            if on_interval:
-                on_interval()
-            if results:  # got event on the wanted channel.
-                break
-        self._cache.update(results)
-        return results
-
-    def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
-        wait = self.drain_events
-        with self.app.pool.acquire_channel(block=True) as (conn, channel):
-            binding = self._create_binding(task_id)
-            with self.Consumer(channel, binding,
-                               no_ack=no_ack, accept=self.accept) as consumer:
-                while 1:
-                    try:
-                        return wait(
-                            conn, consumer, timeout, on_interval)[task_id]
-                    except KeyError:
-                        continue
    def remove_pending_result(self, result):
        # Stop tracking *result*; safe to call for unknown ids.
        self._pending_results.pop(result.id, None)
        # XXX cancel queue after result consumed
 
     def _many_bindings(self, ids):
         return [self._create_binding(task_id) for task_id in ids]
 
-    def get_many(self, task_ids, timeout=None, no_ack=True,
+    def xxx_get_many(self, task_ids, timeout=None, no_ack=True,
                  on_message=None, on_interval=None,
                  now=monotonic, getfields=itemgetter('status', 'task_id'),
                  READY_STATES=states.READY_STATES,

+ 14 - 0
celery/backends/base.py

@@ -107,6 +107,7 @@ class BaseBackend(object):
         self.accept = prepare_accept_content(
             conf.accept_content if accept is None else accept,
         )
+        self._pending_results = {}
 
     def mark_as_started(self, task_id, **meta):
         """Mark a task as started"""
@@ -221,6 +222,19 @@ class BaseBackend(object):
                      content_encoding=self.content_encoding,
                      accept=self.accept)
 
    def wait_for_pending(self, result, timeout=None, interval=0.5,
                         no_ack=True, on_interval=None, callback=None,
                         propagate=True):
        """Poll for *result* until ready, then return (or raise) its value.

        Generic implementation built on :meth:`wait_for`; backends with a
        push-based result consumer may override this.
        """
        meta = self.wait_for(
            result.id, timeout=timeout,
            interval=interval,
            on_interval=on_interval,
            no_ack=no_ack,
        )
        if meta:
            # Cache the meta on the result before raising/returning so
            # subsequent accesses don't hit the backend again.
            result._maybe_set_cache(meta)
            return result.maybe_throw(propagate=propagate, callback=callback)
+
     def wait_for(self, task_id,
                  timeout=None, interval=0.5, no_ack=True, on_interval=None):
         """Wait for task and return its result.

+ 68 - 26
celery/result.py

@@ -14,7 +14,7 @@ from collections import OrderedDict, deque
 from contextlib import contextmanager
 from copy import copy
 
-from amqp import promise
+from amqp.promise import Thenable, barrier, promise
 from kombu.utils import cached_property
 
 from . import current_app
@@ -86,8 +86,17 @@ class AsyncResult(ResultBase):
         self.id = id
         self.backend = backend or self.app.backend
         self.parent = parent
+        self.on_ready = promise(self._on_fulfilled)
         self._cache = None
 
    def then(self, callback, on_error=None):
        # Register with the backend so its consumer fulfills on_ready
        # when the result message arrives.
        self.backend.add_pending_result(self)
        return self.on_ready.then(callback, on_error)
+
    def _on_fulfilled(self, result):
        # Promise callback: once fulfilled we no longer need to be tracked.
        self.backend.remove_pending_result(self)
        return result
+
     def as_tuple(self):
         parent = self.parent
         return (self.id, parent and parent.as_tuple()), None
@@ -159,28 +168,22 @@ class AsyncResult(ResultBase):
 
         if self._cache:
             if propagate:
-                self.maybe_reraise()
+                self.maybe_throw()
             return self.result
 
-        meta = self.backend.wait_for(
-            self.id, timeout=timeout,
+        self.backend.add_pending_result(self)
+        return self.backend.wait_for_pending(
+            self, timeout=timeout,
             interval=interval,
             on_interval=_on_interval,
             no_ack=no_ack,
+            propagate=propagate,
         )
-        if meta:
-            self._maybe_set_cache(meta)
-            state = meta['status']
-            if state in PROPAGATE_STATES and propagate:
-                raise meta['result']
-            if callback is not None:
-                callback(self.id, meta['result'])
-            return meta['result']
     wait = get  # deprecated alias to :meth:`get`.
 
    def _maybe_reraise_parent_error(self):
        # Walk ancestors root-first, re-raising the first propagating error.
        for node in reversed(list(self._parents())):
            node.maybe_throw()
 
     def _parents(self):
         node = self.parent
@@ -268,9 +271,17 @@ class AsyncResult(ResultBase):
         """Returns :const:`True` if the task failed."""
         return self.state == states.FAILURE
 
-    def maybe_reraise(self):
-        if self.state in states.PROPAGATE_STATES:
-            raise self.result
    def throw(self, *args, **kwargs):
        # Delegate to the on_ready promise so attached callbacks see the error.
        self.on_ready.throw(*args, **kwargs)

    def maybe_throw(self, propagate=True, callback=None):
        """Return the result value, raising it first if it is an error state.

        Uses the cached meta when available, otherwise fetches it from the
        backend.  *callback(id, value)* is invoked when given.
        """
        cache = self._get_task_meta() if self._cache is None else self._cache
        state, value = cache['status'], cache['result']
        if state in states.PROPAGATE_STATES and propagate:
            self.throw(value)
        if callback is not None:
            callback(self.id, value)
        return value
 
     def build_graph(self, intermediate=False, formatter=None):
         graph = DependencyGraph(
@@ -333,8 +344,10 @@ class AsyncResult(ResultBase):
    def _maybe_set_cache(self, meta):
        # Cache *meta* and fulfill the on_ready promise once the task has
        # reached any ready state (success or error).
        if meta:
            state = meta['status']
            if state in states.READY_STATES:
                d = self._set_cache(self.backend.meta_from_decoded(meta))
                # Fulfill with self so .then() callbacks get the result object.
                self.on_ready(self)
                return d
        return meta
 
     def _get_task_meta(self):
@@ -405,6 +418,7 @@ class AsyncResult(ResultBase):
     @task_id.setter  # noqa
     def task_id(self, id):
         self.id = id
+Thenable.register(AsyncResult)
 
 
 class ResultSet(ResultBase):
@@ -421,6 +435,7 @@ class ResultSet(ResultBase):
    def __init__(self, results, app=None, **kwargs):
        self._app = app
        self.results = results
        # Barrier promise: fulfilled once every member result is ready.
        self.on_ready = barrier(self.results, (self,), callback=self._on_ready)
 
     def add(self, result):
         """Add :class:`AsyncResult` as a new member of the set.
@@ -430,6 +445,10 @@ class ResultSet(ResultBase):
         """
         if result not in self.results:
             self.results.append(result)
+            self.ready.add(result)
+
    def _on_ready(self, result):
        # Barrier callback: a member is done, stop tracking it on the backend.
        self.backend.remove_pending_result(result)
 
     def remove(self, result):
         """Remove result from the set; it must be a member.
@@ -482,9 +501,9 @@ class ResultSet(ResultBase):
         """
         return any(result.failed() for result in self.results)
 
-    def maybe_reraise(self):
+    def maybe_throw(self, callback=None, propagate=True):
         for result in self.results:
-            result.maybe_reraise()
+            result.maybe_throw(callback=callback, propagate=propagate)
 
     def waiting(self):
         """Are any of the tasks incomplete?
@@ -655,6 +674,12 @@ class ResultSet(ResultBase):
                 results.append(value)
         return results
 
    def then(self, callback, on_error=None):
        # Track every member and chain each member's promise into the
        # set-level barrier before attaching the user callback.
        for result in self.results:
            self.backend.add_pending_result(result)
            result.on_ready.then(self.on_ready)
        return self.on_ready.then(callback, on_error)
+
     def iter_native(self, timeout=None, interval=0.5, no_ack=True,
                     on_message=None, on_interval=None):
         """Backend optimized version of :meth:`iterate`.
@@ -670,12 +695,21 @@ class ResultSet(ResultBase):
         """
         results = self.results
         if not results:
-            return iter([])
-        return self.backend.get_many(
-            {r.id for r in results},
-            timeout=timeout, interval=interval, no_ack=no_ack,
-            on_message=on_message, on_interval=on_interval,
-        )
+            raise StopIteration()
+        ids = set()
+        for result in self.results:
+            self.backend.add_pending_result(result)
+            ids.add(result.id)
+        bucket = deque()
+        for _ in  self.backend.collect_for_pending(
+                self,
+                bucket=bucket,
+                timeout=timeout, interval=interval, no_ack=no_ack,
+                on_message=on_message, on_interval=on_interval):
+            while bucket:
+                result = bucket.popleft()
+                if result.id in ids:
+                    yield result.id, result._cache
 
     def join_native(self, timeout=None, propagate=True,
                     interval=0.5, callback=None, no_ack=True,
@@ -749,6 +783,7 @@ class ResultSet(ResultBase):
    @property
    def backend(self):
        # Prefer the bound app's backend; fall back to the first member's.
        return self.app.backend if self.app else self.results[0].backend
+Thenable.register(ResultSet)
 
 
 class GroupResult(ResultSet):
@@ -822,6 +857,7 @@ class GroupResult(ResultSet):
         return (
             backend or (self.app.backend if self.app else current_app.backend)
         ).restore_group(id)
# BUG FIX: this line follows the GroupResult class but registered
# ResultSet (which is already registered) instead of GroupResult.
Thenable.register(GroupResult)
 
 
 class EagerResult(AsyncResult):
@@ -832,6 +868,11 @@ class EagerResult(AsyncResult):
         self._result = ret_value
         self._state = state
         self._traceback = traceback
+        self.on_ready = promise()
+        self.on_ready()
+
    def then(self, callback, on_error=None):
        # Eager results are already fulfilled, so the callback fires
        # immediately; no backend registration is needed.
        return self.on_ready.then(callback, on_error)
 
     def _get_task_meta(self):
         return {'task_id': self.id, 'result': self._result, 'status':
@@ -887,6 +928,7 @@ class EagerResult(AsyncResult):
     @property
     def supports_native_join(self):
         return False
+Thenable.register(EagerResult)
 
 
 def result_from_tuple(r, app=None):

+ 30 - 0
funtests/stress/t.py

@@ -0,0 +1,30 @@
+from celery import group
+import socket
+from stress.app import add, raising
+
def on_ready(result):
    """Promise callback: print the fulfilled *result*."""
    message = 'RESULT: %r' % (result,)
    print(message)
+
def test():
    """Exercise callback-, group- and single-result code paths once."""
    # Fire-and-forget group with a completion callback.
    group(add.s(i, i) for i in range(10)).delay().then(on_ready)

    # Blocking group join.
    p = group(add.s(i, i) for i in range(10)).delay()
    print(p.get(timeout=5))

    # Four sequential single-task round-trips (was four copy-pasted blocks).
    for _ in range(4):
        p = add.delay(2, 2)
        print(p.get(timeout=5))

    # Error propagation path.
    p = raising.delay()
    try:
        print(p.get(timeout=5))
    except Exception as exc:
        # BUG FIX: the trailing comma was outside the format tuple
        # (``% (exc),``), passing print a stray argument position.
        print('raised: %r' % (exc,))
+
+
if __name__ == '__main__':
    # Guard so importing this stress module does not start the run;
    # loop variable renamed to ``_`` as it is unused.
    for _ in range(100):
        test()