Fixed chord as a chord callback for amqp result backend (also closes #1905)

Ask Solem, 11 years ago
commit 881786cffe
3 files changed, 65 insertions and 27 deletions
  1. celery/backends/amqp.py (+11, -7)
  2. celery/backends/base.py (+6, -2)
  3. celery/result.py (+48, -18)
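
For context, the commit message refers to using one chord as the callback of another chord, a pattern that previously could not be joined correctly with the amqp result backend. A rough sketch of that pattern (an illustration only, not the reproduction from #1905), assuming an app configured with the amqp result backend and two hypothetical tasks add(x, y) and tsum(nums):

    from celery import chord, group

    # `add` and `tsum` are hypothetical tasks; the app is assumed to use
    # the amqp result backend (CELERY_RESULT_BACKEND = 'amqp').
    callback = chord(group(add.s(i, i) for i in range(4)), tsum.s())
    workflow = chord(group(add.s(2, 2), add.s(4, 4)), callback)
    result = workflow.apply_async()
    print(result.get(timeout=30))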

+ 11 - 7
celery/backends/amqp.py

@@ -141,6 +141,7 @@ class AMQPBackend(BaseBackend):
         return [self._create_binding(task_id)]
 
     def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
+                 no_ack=True, on_interval=None,
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES,
                  **kwargs):
@@ -150,7 +151,8 @@ class AMQPBackend(BaseBackend):
             meta = cached_meta
         else:
             try:
-                meta = self.consume(task_id, timeout=timeout)
+                meta = self.consume(task_id, timeout=timeout, no_ack=no_ack,
+                                    on_interval=on_interval)
             except socket.timeout:
                 raise TimeoutError('The operation timed out.')
 
@@ -193,7 +195,7 @@ class AMQPBackend(BaseBackend):
     poll = get_task_meta  # XXX compat
 
     def drain_events(self, connection, consumer,
-                     timeout=None, now=monotonic, wait=None):
+                     timeout=None, on_interval=None, now=monotonic, wait=None):
         wait = wait or connection.drain_events
         results = {}
 
@@ -209,27 +211,29 @@ class AMQPBackend(BaseBackend):
             if timeout and now() - time_start >= timeout:
                 raise socket.timeout()
             wait(timeout=timeout)
+            if on_interval:
+                on_interval()
             if results:  # got event on the wanted channel.
                 break
         self._cache.update(results)
         return results
 
-    def consume(self, task_id, timeout=None):
+    def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
         wait = self.drain_events
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             binding = self._create_binding(task_id)
             with self.Consumer(channel, binding,
-                               no_ack=True, accept=self.accept) as consumer:
+                               no_ack=no_ack, accept=self.accept) as consumer:
                 while 1:
                     try:
-                        return wait(conn, consumer, timeout)[task_id]
+                        return wait(conn, consumer, timeout, on_interval)[task_id]
                     except KeyError:
                         continue
 
     def _many_bindings(self, ids):
         return [self._create_binding(task_id) for task_id in ids]
 
-    def get_many(self, task_ids, timeout=None,
+    def get_many(self, task_ids, timeout=None, no_ack=True,
                  now=monotonic, getfields=itemgetter('status', 'task_id'),
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
@@ -263,7 +267,7 @@ class AMQPBackend(BaseBackend):
 
             bindings = self._many_bindings(task_ids)
             with self.Consumer(channel, bindings, on_message=on_message,
-                               accept=self.accept, no_ack=True):
+                               accept=self.accept, no_ack=no_ack):
                 wait = conn.drain_events
                 popleft = results.popleft
                 while ids:
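
In the amqp backend, `no_ack=True` keeps the previous behaviour of auto-acknowledging result messages (`no_ack=False` leaves them unacked), and `on_interval` is called each time `drain_events` wakes up while waiting. A sketch of driving these internals directly, using the signatures from the diff above; `app`, `task_id`, and the callback are assumptions, and `wait_for` is not a public API:

    # Sketch only: exercising AMQPBackend.wait_for() as changed above.
    def on_interval():
        # Invoked periodically while draining result messages, e.g. to
        # re-raise a failed parent task (see celery/result.py below).
        pass

    value = app.backend.wait_for(
        task_id,                  # assumed: id of a previously sent task
        timeout=10.0,
        no_ack=True,              # auto-acknowledge the result message
        on_interval=on_interval,  # called between drain_events wake-ups
    )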

+ 6 - 2
celery/backends/base.py

@@ -177,7 +177,9 @@ class BaseBackend(object):
                      content_encoding=self.content_encoding,
                      accept=self.accept)
 
-    def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
+    def wait_for(self, task_id,
+                 timeout=None, propagate=True, interval=0.5, no_ack=True,
+                 on_interval=None):
         """Wait for task and return its result.
 
         If the task raises an exception, this exception
@@ -200,6 +202,8 @@ class BaseBackend(object):
                 if propagate:
                     raise result
                 return result
+            if on_interval:
+                on_interval()
             # avoid hammering the CPU checking status.
             time.sleep(interval)
             time_elapsed += interval
@@ -430,7 +434,7 @@ class KeyValueStoreBackend(BaseBackend):
                         for i, value in enumerate(values)
                         if value is not None)
 
-    def get_many(self, task_ids, timeout=None, interval=0.5,
+    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
                  READY_STATES=states.READY_STATES):
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
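
For polling backends (e.g. a KeyValueStoreBackend such as redis), `BaseBackend.wait_for` now accepts the same keywords: `no_ack` is taken for API parity with the amqp backend (the polling loop does not use it), and `on_interval` is invoked once per polling iteration, so a callback that raises aborts the wait. A sketch under the same assumptions as above (`app` and `task_id` are placeholders, `should_abort` is a hypothetical predicate):

    from celery.exceptions import TimeoutError

    # Called once per `interval` (default 0.5s) between status polls.
    def stop_early():
        if should_abort():
            raise TimeoutError('wait aborted by on_interval callback')

    value = app.backend.wait_for(
        task_id, timeout=60, interval=0.5,
        no_ack=True, on_interval=stop_early,
    )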

+ 48 - 18
celery/result.py

@@ -118,7 +118,8 @@ class AsyncResult(ResultBase):
                                 terminate=terminate, signal=signal,
                                 reply=wait, timeout=timeout)
 
-    def get(self, timeout=None, propagate=True, interval=0.5):
+    def get(self, timeout=None, propagate=True, interval=0.5, no_ack=True,
+            follow_parents=True):
         """Wait until task is ready, and return its result.
 
         .. warning::
@@ -133,6 +134,10 @@ class AsyncResult(ResultBase):
            retrieve the result.  Note that this does not have any effect
            when using the amqp result store backend, as it does not
            use polling.
+        :keyword no_ack: Enable amqp no ack (automatically acknowledge
+            message).  If this is :const:`False` then the message will
+            **not be acked**.
+        :keyword follow_parents: Reraise any exception raised by parent task.
 
         :raises celery.exceptions.TimeoutError: if `timeout` is not
             :const:`None` and the result does not arrive within `timeout`
@@ -143,15 +148,24 @@ class AsyncResult(ResultBase):
 
         """
         assert_will_not_block()
-        if propagate and self.parent:
-            for node in reversed(list(self._parents())):
-                node.get(propagate=True, timeout=timeout, interval=interval)
-
-        return self.backend.wait_for(self.id, timeout=timeout,
-                                     propagate=propagate,
-                                     interval=interval)
+        on_interval = None
+        if follow_parents and propagate and self.parent:
+            on_interval = self._maybe_reraise_parent_error
+            on_interval()
+
+        return self.backend.wait_for(
+            self.id, timeout=timeout,
+            propagate=propagate,
+            interval=interval,
+            on_interval=on_interval,
+            no_ack=no_ack,
+        )
     wait = get  # deprecated alias to :meth:`get`.
 
+    def _maybe_reraise_parent_error(self):
+        for node in reversed(list(self._parents())):
+            node.maybe_reraise()
+
     def _parents(self):
         node = self.parent
         while node:
@@ -238,6 +252,10 @@ class AsyncResult(ResultBase):
         """Returns :const:`True` if the task failed."""
         return self.state == states.FAILURE
 
+    def maybe_reraise(self):
+        if self.state in states.PROPAGATE_STATES:
+            raise self.result
+
     def build_graph(self, intermediate=False, formatter=None):
         graph = DependencyGraph(
             formatter=formatter or GraphFormatter(root=self.id, shape='oval'),
@@ -426,6 +444,10 @@ class ResultSet(ResultBase):
         """
         return any(result.failed() for result in self.results)
 
+    def maybe_reraise(self):
+        for result in self.results:
+            result.maybe_reraise()
+
     def waiting(self):
         """Are any of the tasks incomplete?
 
@@ -506,7 +528,8 @@ class ResultSet(ResultBase):
             if timeout and elapsed >= timeout:
                 raise TimeoutError('The operation timed out')
 
-    def get(self, timeout=None, propagate=True, interval=0.5, callback=None):
+    def get(self, timeout=None, propagate=True, interval=0.5,
+            callback=None, no_ack=True):
         """See :meth:`join`
 
         This is here for API compatibility with :class:`AsyncResult`,
@@ -516,9 +539,10 @@ class ResultSet(ResultBase):
         """
         return (self.join_native if self.supports_native_join else self.join)(
             timeout=timeout, propagate=propagate,
-            interval=interval, callback=callback)
+            interval=interval, callback=callback, no_ack=no_ack)
 
-    def join(self, timeout=None, propagate=True, interval=0.5, callback=None):
+    def join(self, timeout=None, propagate=True, interval=0.5,
+             callback=None, no_ack=True):
         """Gathers the results of all tasks as a list in order.
 
         .. note::
@@ -557,6 +581,10 @@ class ResultSet(ResultBase):
                            ``result = app.AsyncResult(task_id)`` (both will
                            take advantage of the backend cache anyway).
 
+        :keyword no_ack: Automatic message acknowledgement (Note that if this
+            is set to :const:`False` then the messages *will not be
+            acknowledged*).
+
         :raises celery.exceptions.TimeoutError: if ``timeout`` is not
             :const:`None` and the operation takes longer than ``timeout``
             seconds.
@@ -573,16 +601,17 @@ class ResultSet(ResultBase):
                 remaining = timeout - (monotonic() - time_start)
                 if remaining <= 0.0:
                     raise TimeoutError('join operation timed out')
-            value = result.get(timeout=remaining,
-                               propagate=propagate,
-                               interval=interval)
+            value = result.get(
+                timeout=remaining, propagate=propagate,
+                interval=interval, no_ack=no_ack,
+            )
             if callback:
                 callback(result.id, value)
             else:
                 results.append(value)
         return results
 
-    def iter_native(self, timeout=None, interval=0.5):
+    def iter_native(self, timeout=None, interval=0.5, no_ack=True):
         """Backend optimized version of :meth:`iterate`.
 
         .. versionadded:: 2.2
@@ -598,11 +627,12 @@ class ResultSet(ResultBase):
         if not results:
             return iter([])
         return results[0].backend.get_many(
-            set(r.id for r in results), timeout=timeout, interval=interval,
+            set(r.id for r in results),
+            timeout=timeout, interval=interval, no_ack=no_ack,
         )
 
     def join_native(self, timeout=None, propagate=True,
-                    interval=0.5, callback=None):
+                    interval=0.5, callback=None, no_ack=True):
         """Backend optimized version of :meth:`join`.
 
         .. versionadded:: 2.2
@@ -619,7 +649,7 @@ class ResultSet(ResultBase):
             (result.id, i) for i, result in enumerate(self.results)
         )
         acc = None if callback else [None for _ in range(len(self))]
-        for task_id, meta in self.iter_native(timeout, interval):
+        for task_id, meta in self.iter_native(timeout, interval, no_ack):
             value = meta['result']
             if propagate and meta['status'] in states.PROPAGATE_STATES:
                 raise value
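
Taken together, the user-facing surface after this commit: `AsyncResult.get()` gains `no_ack` and `follow_parents`, `ResultSet.get()/join()/join_native()/iter_native()` gain `no_ack`, and both result classes gain `maybe_reraise()`. A small usage sketch, assuming `res` is an AsyncResult produced by a chain or chord and `rs` is a ResultSet:

    # Wait for the result; failed parent tasks are re-raised while waiting.
    value = res.get(timeout=30, no_ack=True, follow_parents=True)

    # Re-raise a stored failure immediately, without waiting on a pending task.
    res.maybe_reraise()

    # The same check across every member of a result set.
    rs.maybe_reraise()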