Pārlūkot izejas kodu

[result][rpc] RPC backend get_task_meta needs to handle out of band messages

Ask Solem 9 gadi atpakaļ
vecāks
revīzija
d8d19de9d8

+ 28 - 7
celery/backends/amqp.py

@@ -85,7 +85,8 @@ class ResultConsumer(BaseResultConsumer):
             self._consumer.consume()
 
     def cancel_for(self, queue):
-        self._consumer.cancel_by_queue(queue.name)
+        if self._consumer:
+            self._consumer.cancel_by_queue(queue.name)
 
 
 class AMQPBackend(base.Backend, AsyncBackendMixin):
@@ -115,6 +116,7 @@ class AMQPBackend(base.Backend, AsyncBackendMixin):
         super(AMQPBackend, self).__init__(app, **kwargs)
         conf = self.app.conf
         self._connection = connection
+        self._out_of_band = {}
         self.persistent = self.prepare_persistent(persistent)
         self.delivery_mode = 2 if self.persistent else 1
         exchange = exchange or conf.result_exchange
@@ -191,7 +193,20 @@ class AMQPBackend(base.Backend, AsyncBackendMixin):
     def on_reply_declare(self, task_id):
         return [self._create_binding(task_id)]
 
+    def on_out_of_band_result(self, task_id, message):
+        if self.result_consumer:
+            self.result_consumer.on_out_of_band_result(message)
+        self._out_of_band[task_id] = message
+
     def get_task_meta(self, task_id, backlog_limit=1000):
+        try:
+            buffered = self._out_of_band.pop(task_id)
+        except KeyError:
+            pass
+        else:
+            payload = self._cache[task_id] = self.meta_from_decoded(
+                buffered.payload)
+            return payload
         # Polling and using basic_get
         with self.app.pool.acquire_channel(block=True) as (_, channel):
             binding = self._create_binding(task_id)(channel)
@@ -204,13 +219,19 @@ class AMQPBackend(base.Backend, AsyncBackendMixin):
                 )
                 if not acc:  # no more messages
                     break
-                if acc.payload['task_id'] == task_id:
+                try:
+                    message_task_id = acc.properties['correlation_id']
+                except (AttributeError, KeyError):
+                    message_task_id = acc.payload['task_id']
+                if message_task_id == task_id:
                     prev, latest = latest, acc
-                if prev:
-                    # backends are not expected to keep history,
-                    # so we delete everything except the most recent state.
-                    prev.ack()
-                    prev = None
+                    if prev:
+                        # backends are not expected to keep history,
+                        # so we delete everything except the most recent state.
+                        prev.ack()
+                        prev = None
+                else:
+                    self.on_out_of_band_result(message_task_id, acc)
             else:
                 raise self.BacklogLimitExceeded(task_id)
 

+ 3 - 0
celery/backends/async.py

@@ -190,6 +190,9 @@ class BaseResultConsumer(object):
         finally:
             self.on_message = prev_on_m
 
+    def on_out_of_band_result(self, message):
+        self.on_state_change(message.payload, message)
+
     def on_state_change(self, meta, message):
         if self.on_message:
             self.on_message(meta)

+ 1 - 0
celery/tests/backends/test_amqp.py

@@ -155,6 +155,7 @@ class test_AMQPBackend(AppCase):
             def __init__(self, **merge):
                 self.payload = dict({'status': states.STARTED,
                                      'result': None}, **merge)
+                self.properties = {'correlation_id': merge.get('task_id')}
                 self.body = pickle.dumps(self.payload)
                 self.content_type = 'application/x-python-serialize'
                 self.content_encoding = 'binary'