瀏覽代碼

on_message callback added

Ilya Georgievsky 10 年之前
父節點
當前提交
27f85e786c
共有 2 個文件被更改,包括 14 次插入8 次删除
  1. 5 3
      celery/backends/amqp.py
  2. 9 5
      celery/result.py

+ 5 - 3
celery/backends/amqp.py

@@ -231,7 +231,7 @@ class AMQPBackend(BaseBackend):
     def _many_bindings(self, ids):
         return [self._create_binding(task_id) for task_id in ids]
 
-    def get_many(self, task_ids, timeout=None, no_ack=True,
+    def get_many(self, task_ids, timeout=None, no_ack=True, on_message=None,
                  now=monotonic, getfields=itemgetter('status', 'task_id'),
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
@@ -254,15 +254,17 @@ class AMQPBackend(BaseBackend):
             push_cache = self._cache.__setitem__
             decode_result = self.meta_from_decoded
 
-            def on_message(message):
+            def _on_message(message):
                 body = decode_result(message.decode())
+                if on_message is not None:
+                    on_message(body)
                 state, uid = getfields(body)
                 if state in READY_STATES:
                     push_result(body) \
                         if uid in task_ids else push_cache(uid, body)
 
             bindings = self._many_bindings(task_ids)
-            with self.Consumer(channel, bindings, on_message=on_message,
+            with self.Consumer(channel, bindings, on_message=_on_message,
                                accept=self.accept, no_ack=no_ack):
                 wait = conn.drain_events
                 popleft = results.popleft

+ 9 - 5
celery/result.py

@@ -567,7 +567,7 @@ class ResultSet(ResultBase):
                 raise TimeoutError('The operation timed out')
 
     def get(self, timeout=None, propagate=True, interval=0.5,
-            callback=None, no_ack=True):
+            callback=None, no_ack=True, on_message=None):
         """See :meth:`join`
 
         This is here for API compatibility with :class:`AsyncResult`,
@@ -577,7 +577,7 @@ class ResultSet(ResultBase):
         """
         return (self.join_native if self.supports_native_join else self.join)(
             timeout=timeout, propagate=propagate,
-            interval=interval, callback=callback, no_ack=no_ack)
+            interval=interval, callback=callback, no_ack=no_ack, on_message=on_message)
 
     def join(self, timeout=None, propagate=True, interval=0.5,
              callback=None, no_ack=True):
@@ -649,7 +649,8 @@ class ResultSet(ResultBase):
                 results.append(value)
         return results
 
-    def iter_native(self, timeout=None, interval=0.5, no_ack=True):
+    def iter_native(self, timeout=None, interval=0.5, no_ack=True,
+                    on_message=None):
         """Backend optimized version of :meth:`iterate`.
 
         .. versionadded:: 2.2
@@ -667,10 +668,12 @@ class ResultSet(ResultBase):
         return self.backend.get_many(
             set(r.id for r in results),
             timeout=timeout, interval=interval, no_ack=no_ack,
+            on_message=on_message,
         )
 
     def join_native(self, timeout=None, propagate=True,
-                    interval=0.5, callback=None, no_ack=True):
+                    interval=0.5, callback=None, no_ack=True,
+                    on_message=None):
         """Backend optimized version of :meth:`join`.
 
         .. versionadded:: 2.2
@@ -687,7 +690,8 @@ class ResultSet(ResultBase):
             result.id: i for i, result in enumerate(self.results)
         }
         acc = None if callback else [None for _ in range(len(self))]
-        for task_id, meta in self.iter_native(timeout, interval, no_ack):
+        for task_id, meta in self.iter_native(timeout, interval, no_ack,
+                                              on_message):
             value = meta['result']
             if propagate and meta['status'] in states.PROPAGATE_STATES:
                 raise value