Przeglądaj źródła

amqrpc result backend apparently working

Ask Solem 13 lat temu
rodzic
commit
7f933deee0
3 zmienionych plików z 56 dodań i 21 usunięć
  1. 35 16
      celery/backends/amqp.py
  2. 19 3
      celery/backends/amqrpc.py
  3. 2 2
      celery/backends/base.py

+ 35 - 16
celery/backends/amqp.py

@@ -15,6 +15,8 @@ import socket
 import threading
 import time
 
+from collections import deque
+
 from kombu.entity import Exchange, Queue
 from kombu.messaging import Consumer, Producer
 
@@ -114,8 +116,6 @@ class AMQPBackend(BaseDictBackend):
         """Send task return value and status."""
         with self.mutex:
             with self.app.amqp.producer_pool.acquire(block=True) as pub:
-                print("PUBLISH TO exchange=%r rkey=%r" % (self.exchange,
-                    self._routing_key(task_id)))
                 pub.publish({'task_id': task_id, 'status': status,
                              'result': self.encode_result(result, status),
                              'traceback': traceback,
@@ -175,14 +175,14 @@ class AMQPBackend(BaseDictBackend):
                     return {'status': states.PENDING, 'result': None}
     poll = get_task_meta  # XXX compat
 
-    def drain_events(self, connection, consumer, timeout=None, now=time.time):
-        wait = connection.drain_events
+    def drain_events(self, connection, consumer, timeout=None, now=time.time,
+            wait=None):
+        wait = wait or connection.drain_events
         results = {}
 
         def callback(meta, message):
             if meta['status'] in states.READY_STATES:
-                uuid = repair_uuid(message.delivery_info['routing_key'])
-                results[uuid] = meta
+                results[meta['task_id']] = meta
 
         consumer.callbacks[:] = [callback]
         time_start = now()
@@ -198,12 +198,20 @@ class AMQPBackend(BaseDictBackend):
         return results
 
     def consume(self, task_id, timeout=None):
+        wait = self.drain_events
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             binding = self._create_binding(task_id)
             with self.Consumer(channel, binding, no_ack=True) as consumer:
-                return self.drain_events(conn, consumer, timeout).values()[0]
+                while 1:
+                    try:
+                        return wait(conn, consumer, timeout)[task_id]
+                    except KeyError:
+                        continue
+
+    def _many_bindings(self, ids):
+        return [self._create_binding(task_id) for task_id in ids]
 
-    def get_many(self, task_ids, timeout=None, **kwargs):
+    def get_many(self, task_ids, timeout=None, now=time.time, **kwargs):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
             cached_ids = set()
@@ -216,15 +224,26 @@ class AMQPBackend(BaseDictBackend):
                     if cached['status'] in states.READY_STATES:
                         yield task_id, cached
                         cached_ids.add(task_id)
-            ids ^= cached_ids
-
-            bindings = [self._create_binding(task_id) for task_id in task_ids]
-            with self.Consumer(channel, bindings, no_ack=True) as consumer:
+            ids.difference_update(cached_ids)
+            results = deque()
+
+            def callback(meta, message):
+                if meta['status'] in states.READY_STATES:
+                    results.append(meta)
+
+            bindings = self._many_bindings(task_id)
+            with self.Consumer(channel, bindings, callbacks=[callback],
+                    no_ack=True):
+                wait = conn.drain_events
+                popleft = results.popleft
                 while ids:
-                    r = self.drain_events(conn, consumer, timeout)
-                    ids ^= set(r)
-                    for ready_id, ready_meta in r.iteritems():
-                        yield ready_id, ready_meta
+                    wait(timeout=timeout)
+                    while results:
+                        meta = popleft()
+                        task_id = meta['task_id']
+                        ids.discard(task_id)
+                        self._cache[task_id] = meta
+                        yield task_id, meta
 
     def reload_task_result(self, task_id):
         raise NotImplementedError(

+ 19 - 3
celery/backends/amqrpc.py

@@ -1,10 +1,21 @@
+# -*- coding: utf-8 -*-
+"""
+    celery.backends.amqrpc
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    RPC-style result backend, using reply-to and one queue per client.
+
+"""
 from __future__ import absolute_import
+from __future__ import with_statement
 
+import kombu
 import os
 import uuid
 
 from threading import local
 
+from kombu.common import maybe_declare
 from celery.backends import amqp
 
 try:
@@ -21,19 +32,24 @@ _nodeid = uuid.getnode()
 class AMQRPCBackend(amqp.AMQPBackend):
     _tls = local()
 
+    class Consumer(kombu.Consumer):
+        auto_declare = False
+
     def _create_exchange(self, name, type='direct', persistent=False):
         return self.Exchange('c.amqrpc', type=type, delivery_mode=1,
                 durable=False, auto_delete=False)
 
     def on_task_apply(self, task_id):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
-            self.binding(channel).declare()
+            maybe_declare(self.binding(channel), retry=True)
             return {'reply_to': self.oid}
 
     def _create_binding(self, task_id):
-        print("BINDING: %r" % (self.binding, ))
         return self.binding
 
+    def _many_bindings(self, ids):
+        return [self.binding]
+
     def _routing_key(self, task_id):
         from celery import current_task
         return current_task.request.reply_to
@@ -41,7 +57,7 @@ class AMQRPCBackend(amqp.AMQPBackend):
     @property
     def binding(self):
         return self.Queue(self.oid, self.exchange, self.oid,
-                          durable=False, auto_delete=True)
+                          durable=False, auto_delete=False)
 
     @property
     def oid(self):

+ 2 - 2
celery/backends/base.py

@@ -403,14 +403,14 @@ class KeyValueStoreBackend(BaseDictBackend):
                     yield bytes_to_str(task_id), cached
                     cached_ids.add(task_id)
 
-        ids ^= cached_ids
+        ids.difference_update(cached_ids)
         iterations = 0
         while ids:
             keys = list(ids)
             r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                     for k in keys]), keys)
             self._cache.update(r)
-            ids ^= set(map(bytes_to_str, r))
+            ids.difference_update(set(map(bytes_to_str, r)))
             for key, value in r.iteritems():
                 yield bytes_to_str(key), value
             if timeout and iterations * interval >= timeout: