Jelajahi Sumber

Merge pull request #2550 from xBeAsTx/master

on_message callback added
PMickael 10 tahun lalu
induk
melakukan
e58405fb34
3 mengubah file dengan 48 tambahan dan 9 penghapusan
  1. 5 3
      celery/backends/amqp.py
  2. 13 6
      celery/result.py
  3. 30 0
      celery/tests/backends/test_amqp.py

+ 5 - 3
celery/backends/amqp.py

@@ -231,7 +231,7 @@ class AMQPBackend(BaseBackend):
     def _many_bindings(self, ids):
         return [self._create_binding(task_id) for task_id in ids]
 
-    def get_many(self, task_ids, timeout=None, no_ack=True,
+    def get_many(self, task_ids, timeout=None, no_ack=True, on_message=None,
                  now=monotonic, getfields=itemgetter('status', 'task_id'),
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
@@ -254,15 +254,17 @@ class AMQPBackend(BaseBackend):
             push_cache = self._cache.__setitem__
             decode_result = self.meta_from_decoded
 
-            def on_message(message):
+            def _on_message(message):
                 body = decode_result(message.decode())
+                if on_message is not None:
+                    on_message(body)
                 state, uid = getfields(body)
                 if state in READY_STATES:
                     push_result(body) \
                         if uid in task_ids else push_cache(uid, body)
 
             bindings = self._many_bindings(task_ids)
-            with self.Consumer(channel, bindings, on_message=on_message,
+            with self.Consumer(channel, bindings, on_message=_on_message,
                                accept=self.accept, no_ack=no_ack):
                 wait = conn.drain_events
                 popleft = results.popleft

+ 13 - 6
celery/result.py

@@ -567,7 +567,7 @@ class ResultSet(ResultBase):
                 raise TimeoutError('The operation timed out')
 
     def get(self, timeout=None, propagate=True, interval=0.5,
-            callback=None, no_ack=True):
+            callback=None, no_ack=True, on_message=None):
         """See :meth:`join`
 
         This is here for API compatibility with :class:`AsyncResult`,
@@ -577,10 +577,10 @@ class ResultSet(ResultBase):
         """
         return (self.join_native if self.supports_native_join else self.join)(
             timeout=timeout, propagate=propagate,
-            interval=interval, callback=callback, no_ack=no_ack)
+            interval=interval, callback=callback, no_ack=no_ack, on_message=on_message)
 
     def join(self, timeout=None, propagate=True, interval=0.5,
-             callback=None, no_ack=True):
+             callback=None, no_ack=True, on_message=None):
         """Gathers the results of all tasks as a list in order.
 
         .. note::
@@ -632,6 +632,9 @@ class ResultSet(ResultBase):
         time_start = monotonic()
         remaining = None
 
+        if on_message is not None:
+            raise Exception('Your backend not suppored on_message callback')
+
         results = []
         for result in self.results:
             remaining = None
@@ -649,7 +652,8 @@ class ResultSet(ResultBase):
                 results.append(value)
         return results
 
-    def iter_native(self, timeout=None, interval=0.5, no_ack=True):
+    def iter_native(self, timeout=None, interval=0.5, no_ack=True,
+                    on_message=None):
         """Backend optimized version of :meth:`iterate`.
 
         .. versionadded:: 2.2
@@ -667,10 +671,12 @@ class ResultSet(ResultBase):
         return self.backend.get_many(
             set(r.id for r in results),
             timeout=timeout, interval=interval, no_ack=no_ack,
+            on_message=on_message,
         )
 
     def join_native(self, timeout=None, propagate=True,
-                    interval=0.5, callback=None, no_ack=True):
+                    interval=0.5, callback=None, no_ack=True,
+                    on_message=None):
         """Backend optimized version of :meth:`join`.
 
         .. versionadded:: 2.2
@@ -687,7 +693,8 @@ class ResultSet(ResultBase):
             result.id: i for i, result in enumerate(self.results)
         }
         acc = None if callback else [None for _ in range(len(self))]
-        for task_id, meta in self.iter_native(timeout, interval, no_ack):
+        for task_id, meta in self.iter_native(timeout, interval, no_ack,
+                                              on_message):
             value = meta['result']
             if propagate and meta['status'] in states.PROPAGATE_STATES:
                 raise value

+ 30 - 0
celery/tests/backends/test_amqp.py

@@ -294,6 +294,36 @@ class test_AMQPBackend(AppCase):
             b.store_result(tids[0], i, states.PENDING)
             list(b.get_many(tids, timeout=0.01))
 
+    def test_get_many_on_message(self):
+        b = self.create_backend(max_cached_results=10)
+
+        tids = []
+        for i in range(10):
+            tid = uuid()
+            b.store_result(tid, '', states.PENDING)
+            b.store_result(tid, 'comment_%i_1' % i, states.STARTED)
+            b.store_result(tid, 'comment_%i_2' % i, states.STARTED)
+            b.store_result(tid, 'final result %i' % i, states.SUCCESS)
+            tids.append(tid)
+
+
+        expected_messages = {}
+        for i, _tid in enumerate(tids):
+            expected_messages[_tid] = []
+            expected_messages[_tid].append( (states.PENDING, '') )
+            expected_messages[_tid].append( (states.STARTED, 'comment_%i_1' % i) )
+            expected_messages[_tid].append( (states.STARTED, 'comment_%i_2' % i) )
+            expected_messages[_tid].append( (states.SUCCESS, 'final result %i' % i) )
+
+        on_message_results = {}
+        def on_message(body):
+            if not body['task_id'] in on_message_results:
+                on_message_results[body['task_id']] = []
+            on_message_results[body['task_id']].append( (body['status'], body['result']) )
+
+        res = list(b.get_many(tids, timeout=1, on_message=on_message))
+        self.assertEqual(sorted(on_message_results), sorted(expected_messages))
+
     def test_get_many_raises_outer_block(self):
 
         class Backend(AMQPBackend):