Browse files

Fixed chord as a chord callback for amqp result backend (also closes #1905)

Ask Solem 11 years ago
parent
commit
881786cffe
3 changed files with 65 additions and 27 deletions
  1. celery/backends/amqp.py  (+11 -7)
  2. celery/backends/base.py  (+6 -2)
  3. celery/result.py  (+48 -18)
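
For context, the scenario this commit fixes is a chord whose callback is itself another chord, run against the amqp result backend. The sketch below only illustrates that workflow shape; the broker URL and the add/tsum tasks are assumptions made for illustration and are not part of the change.

    from celery import Celery, chord

    # Assumed broker/backend configuration; any amqp broker works the same way.
    app = Celery('example', broker='amqp://guest@localhost//', backend='amqp')

    @app.task
    def add(x, y):
        return x + y

    @app.task
    def tsum(numbers):
        return sum(numbers)

    # The inner chord is used as the callback of the outer chord -- the shape
    # that previously misbehaved with the amqp result backend.
    inner = chord((add.s(i, i) for i in range(4)), tsum.s())
    outer = chord((add.s(i, i) for i in range(4)), inner)

    if __name__ == '__main__':
        print(outer.apply_async().get(timeout=10))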

celery/backends/amqp.py  (+11 -7)

@@ -141,6 +141,7 @@ class AMQPBackend(BaseBackend):
         return [self._create_binding(task_id)]
 
     def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
+                 no_ack=True, on_interval=None,
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES,
                  **kwargs):
@@ -150,7 +151,8 @@ class AMQPBackend(BaseBackend):
             meta = cached_meta
         else:
             try:
-                meta = self.consume(task_id, timeout=timeout)
+                meta = self.consume(task_id, timeout=timeout, no_ack=no_ack,
+                                    on_interval=on_interval)
             except socket.timeout:
                 raise TimeoutError('The operation timed out.')
 
@@ -193,7 +195,7 @@ class AMQPBackend(BaseBackend):
     poll = get_task_meta  # XXX compat
 
     def drain_events(self, connection, consumer,
-                     timeout=None, now=monotonic, wait=None):
+                     timeout=None, on_interval=None, now=monotonic, wait=None):
         wait = wait or connection.drain_events
         results = {}
 
@@ -209,27 +211,29 @@ class AMQPBackend(BaseBackend):
             if timeout and now() - time_start >= timeout:
                 raise socket.timeout()
             wait(timeout=timeout)
+            if on_interval:
+                on_interval()
             if results:  # got event on the wanted channel.
                 break
         self._cache.update(results)
         return results
 
-    def consume(self, task_id, timeout=None):
+    def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
         wait = self.drain_events
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             binding = self._create_binding(task_id)
             with self.Consumer(channel, binding,
-                               no_ack=True, accept=self.accept) as consumer:
+                               no_ack=no_ack, accept=self.accept) as consumer:
                 while 1:
                     try:
-                        return wait(conn, consumer, timeout)[task_id]
+                        return wait(conn, consumer, timeout, on_interval)[task_id]
                     except KeyError:
                         continue
 
     def _many_bindings(self, ids):
         return [self._create_binding(task_id) for task_id in ids]
 
-    def get_many(self, task_ids, timeout=None,
+    def get_many(self, task_ids, timeout=None, no_ack=True,
                  now=monotonic, getfields=itemgetter('status', 'task_id'),
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
@@ -263,7 +267,7 @@ class AMQPBackend(BaseBackend):
 
             bindings = self._many_bindings(task_ids)
             with self.Consumer(channel, bindings, on_message=on_message,
-                               accept=self.accept, no_ack=True):
+                               accept=self.accept, no_ack=no_ack):
                 wait = conn.drain_events
                 popleft = results.popleft
                 while ids:
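
The amqp.py changes above thread two new keywords through the backend: no_ack, which is forwarded to the result Consumer, and on_interval, which drain_events now calls between waits on the connection. A minimal sketch of exercising these directly; consume is not a documented public API, and the app and task id here are assumptions for illustration.

    import socket

    def heartbeat():
        # With this change, drain_events invokes the hook after each
        # connection.drain_events call while the result has not arrived yet.
        print('still waiting for the result message...')

    try:
        meta = app.backend.consume(
            'some-task-id',          # assumed id of a published task
            timeout=30,
            no_ack=False,            # consume without auto-acknowledging
            on_interval=heartbeat,   # forwarded down to drain_events
        )
    except socket.timeout:
        print('no result within 30 seconds')
    else:
        print(meta['status'], meta['result'])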

celery/backends/base.py  (+6 -2)

@@ -177,7 +177,9 @@ class BaseBackend(object):
                      content_encoding=self.content_encoding,
                      accept=self.accept)
 
-    def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
+    def wait_for(self, task_id,
+                 timeout=None, propagate=True, interval=0.5, no_ack=True,
+                 on_interval=None):
         """Wait for task and return its result.
 
         If the task raises an exception, this exception
@@ -200,6 +202,8 @@ class BaseBackend(object):
                 if propagate:
                     raise result
                 return result
+            if on_interval:
+                on_interval()
             # avoid hammering the CPU checking status.
             time.sleep(interval)
             time_elapsed += interval
@@ -430,7 +434,7 @@ class KeyValueStoreBackend(BaseBackend):
                         for i, value in enumerate(values)
                         if value is not None)
 
-    def get_many(self, task_ids, timeout=None, interval=0.5,
+    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
                  READY_STATES=states.READY_STATES):
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
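
In base.py the polling BaseBackend.wait_for gains the same on_interval hook: it is invoked once per polling iteration, before the interval sleep, while KeyValueStoreBackend.get_many accepts no_ack so its signature matches the amqp backend even though polling backends have nothing to acknowledge. A sketch of what the hook allows a caller to do, assuming a polling (non-amqp) result backend and an illustrative task id:

    polls = {'count': 0}

    def on_interval():
        # Called by wait_for on every iteration of its polling loop;
        # raising here aborts the wait, which is exactly how
        # AsyncResult.get uses the hook to surface parent failures.
        polls['count'] += 1
        if polls['count'] > 120:        # roughly 60 seconds at interval=0.5
            raise RuntimeError('giving up on this task')

    result = app.backend.wait_for(
        'some-task-id',        # assumed task id
        interval=0.5,
        on_interval=on_interval,
    )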

celery/result.py  (+48 -18)

@@ -118,7 +118,8 @@ class AsyncResult(ResultBase):
                                 terminate=terminate, signal=signal,
                                 reply=wait, timeout=timeout)
 
-    def get(self, timeout=None, propagate=True, interval=0.5):
+    def get(self, timeout=None, propagate=True, interval=0.5, no_ack=True,
+            follow_parents=True):
         """Wait until task is ready, and return its result.
 
         .. warning::
@@ -133,6 +134,10 @@ class AsyncResult(ResultBase):
            retrieve the result.  Note that this does not have any effect
           when using the amqp result store backend, as it does not
           use polling.
+        :keyword no_ack: Enable amqp no ack (automatically acknowledge
+            message).  If this is :const:`False` then the message will
+            **not be acked**.
+        :keyword follow_parents: Reraise any exception raised by parent task.
 
         :raises celery.exceptions.TimeoutError: if `timeout` is not
             :const:`None` and the result does not arrive within `timeout`
@@ -143,15 +148,24 @@ class AsyncResult(ResultBase):
 
         """
         assert_will_not_block()
-        if propagate and self.parent:
-            for node in reversed(list(self._parents())):
-                node.get(propagate=True, timeout=timeout, interval=interval)
-
-        return self.backend.wait_for(self.id, timeout=timeout,
-                                     propagate=propagate,
-                                     interval=interval)
+        on_interval = None
+        if follow_parents and propagate and self.parent:
+            on_interval = self._maybe_reraise_parent_error
+            on_interval()
+
+        return self.backend.wait_for(
+            self.id, timeout=timeout,
+            propagate=propagate,
+            interval=interval,
+            on_interval=on_interval,
+            no_ack=no_ack,
+        )
     wait = get  # deprecated alias to :meth:`get`.
 
+    def _maybe_reraise_parent_error(self):
+        for node in reversed(list(self._parents())):
+            node.maybe_reraise()
+
     def _parents(self):
         node = self.parent
         while node:
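
The hunk above is the core of the fix: instead of calling get() on every parent up front, AsyncResult.get now checks the parent chain once and passes _maybe_reraise_parent_error as the on_interval hook, so a parent failure surfaces while the child result is being waited on. A sketch of the observable behaviour, assuming hypothetical fail and add tasks:

    from celery import chain

    res = chain(fail.s(), add.s(2))()    # parent raises, child never runs

    try:
        # follow_parents=True is the default; the parent's exception is
        # re-raised via the on_interval hook instead of blocking forever.
        res.get(timeout=10)
    except Exception as exc:
        print('parent failure re-raised:', exc)

    # Passing follow_parents=False restores the old "wait only on the
    # child" behaviour.
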
@@ -238,6 +252,10 @@ class AsyncResult(ResultBase):
         """Returns :const:`True` if the task failed."""
         return self.state == states.FAILURE
 
+    def maybe_reraise(self):
+        if self.state in states.PROPAGATE_STATES:
+            raise self.result
+
     def build_graph(self, intermediate=False, formatter=None):
         graph = DependencyGraph(
             formatter=formatter or GraphFormatter(root=self.id, shape='oval'),
@@ -426,6 +444,10 @@ class ResultSet(ResultBase):
         """
         return any(result.failed() for result in self.results)
 
+    def maybe_reraise(self):
+        for result in self.results:
+            result.maybe_reraise()
+
     def waiting(self):
         """Are any of the tasks incomplete?
 
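The two maybe_reraise helpers added above re-raise the stored exception only when a result is in one of the PROPAGATE_STATES (e.g. FAILURE) and do nothing otherwise, which is what makes them safe to call on every interval. Assuming res is an AsyncResult and rset a ResultSet:

    try:
        res.maybe_reraise()      # raises res.result only in a propagate state
    except Exception as exc:
        print('task already failed with:', exc)

    rset.maybe_reraise()         # same check applied to every member result
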
@@ -506,7 +528,8 @@ class ResultSet(ResultBase):
             if timeout and elapsed >= timeout:
                 raise TimeoutError('The operation timed out')
 
-    def get(self, timeout=None, propagate=True, interval=0.5, callback=None):
+    def get(self, timeout=None, propagate=True, interval=0.5,
+            callback=None, no_ack=True):
         """See :meth:`join`
 
         This is here for API compatibility with :class:`AsyncResult`,
@@ -516,9 +539,10 @@ class ResultSet(ResultBase):
         """
         return (self.join_native if self.supports_native_join else self.join)(
             timeout=timeout, propagate=propagate,
-            interval=interval, callback=callback)
+            interval=interval, callback=callback, no_ack=no_ack)
 
-    def join(self, timeout=None, propagate=True, interval=0.5, callback=None):
+    def join(self, timeout=None, propagate=True, interval=0.5,
+             callback=None, no_ack=True):
         """Gathers the results of all tasks as a list in order.
 
         .. note::
@@ -557,6 +581,10 @@ class ResultSet(ResultBase):
                            ``result = app.AsyncResult(task_id)`` (both will
                            take advantage of the backend cache anyway).
 
+        :keyword no_ack: Automatic message acknowledgement (Note that if this
+            is set to :const:`False` then the messages *will not be
+            acknowledged*).
+
         :raises celery.exceptions.TimeoutError: if ``timeout`` is not
             :const:`None` and the operation takes longer than ``timeout``
             seconds.
@@ -573,16 +601,17 @@ class ResultSet(ResultBase):
                 remaining = timeout - (monotonic() - time_start)
                 if remaining <= 0.0:
                     raise TimeoutError('join operation timed out')
-            value = result.get(timeout=remaining,
-                               propagate=propagate,
-                               interval=interval)
+            value = result.get(
+                timeout=remaining, propagate=propagate,
+                interval=interval, no_ack=no_ack,
+            )
             if callback:
                 callback(result.id, value)
             else:
                 results.append(value)
         return results
 
-    def iter_native(self, timeout=None, interval=0.5):
+    def iter_native(self, timeout=None, interval=0.5, no_ack=True):
         """Backend optimized version of :meth:`iterate`.
 
         .. versionadded:: 2.2
@@ -598,11 +627,12 @@ class ResultSet(ResultBase):
         if not results:
             return iter([])
         return results[0].backend.get_many(
-            set(r.id for r in results), timeout=timeout, interval=interval,
+            set(r.id for r in results),
+            timeout=timeout, interval=interval, no_ack=no_ack,
        )
 
     def join_native(self, timeout=None, propagate=True,
-                    interval=0.5, callback=None):
+                    interval=0.5, callback=None, no_ack=True):
         """Backend optimized version of :meth:`join`.
 
         .. versionadded:: 2.2
@@ -619,7 +649,7 @@ class ResultSet(ResultBase):
             (result.id, i) for i, result in enumerate(self.results)
         )
         acc = None if callback else [None for _ in range(len(self))]
-        for task_id, meta in self.iter_native(timeout, interval):
+        for task_id, meta in self.iter_native(timeout, interval, no_ack):
             value = meta['result']
             if propagate and meta['status'] in states.PROPAGATE_STATES:
                 raise value
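
Finally, the no_ack flag is threaded through ResultSet.get, join, iter_native and join_native down to the backend consumers, so callers can keep the result messages unacknowledged. A usage sketch, with the group and the add task assumed for illustration:

    from celery import group

    rset = group(add.s(i, i) for i in range(10))()

    # With the amqp backend this goes through join_native -> get_many, and
    # no_ack=False leaves the result messages unacked on the queue.
    values = rset.get(timeout=30, no_ack=False)
    print(values)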