Browse Source

AMQP result backend now uses a connection pool (new setting: CELERY_AMQP_TASK_RESULT_CONNECTION_MAX). This means results can now be sent when running in eventlet/gevent

Ask Solem 14 years ago
parent
commit
060b266866
2 changed files with 87 additions and 71 deletions
  1. 1 0
      celery/app/defaults.py
  2. 86 71
      celery/backends/amqp.py

+ 1 - 0
celery/app/defaults.py

@@ -87,6 +87,7 @@ NAMESPACES = {
         "STORE_ERRORS_EVEN_IF_IGNORED": Option(False, type="bool"),
         "TASK_RESULT_EXPIRES": Option(timedelta(days=1), type="int"),
         "AMQP_TASK_RESULT_EXPIRES": Option(type="int"),
+        "AMQP_TASK_RESULT_CONNECTION_MAX": Option(type="int", default=1),
         "TASK_ERROR_WHITELIST": Option((), type="tuple"),
         "TASK_SERIALIZER": Option("pickle"),
         "TRACK_STARTED": Option(False, type="bool"),

+ 86 - 71
celery/backends/amqp.py

@@ -34,12 +34,11 @@ class AMQPBackend(BaseDictBackend):
 
     """
 
-    _connection = None
-    _channel = None
+    _pool = None
 
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
-            expires=None, **kwargs):
+            expires=None, connection_max=None, **kwargs):
         super(AMQPBackend, self).__init__(**kwargs)
         conf = self.app.conf
         self._connection = connection
@@ -68,6 +67,8 @@ class AMQPBackend(BaseDictBackend):
             # x-expires must be a signed-int, or long describing
             # the expiry time in milliseconds.
             self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
+        self.connection_max = (connection_max or
+                               conf.CELERY_AMQP_TASK_RESULT_CONNECTION_MAX)
 
     def _create_binding(self, task_id):
         name = task_id.replace("-", "")
@@ -77,16 +78,16 @@ class AMQPBackend(BaseDictBackend):
                      durable=self.persistent,
                      auto_delete=self.auto_delete)
 
-    def _create_producer(self, task_id):
+    def _create_producer(self, task_id, channel):
         binding = self._create_binding(task_id)
-        binding(self.channel).declare()
+        binding(channel).declare()
 
-        return Producer(self.channel, exchange=self.exchange,
+        return Producer(channel, exchange=self.exchange,
                         routing_key=task_id.replace("-", ""),
                         serializer=self.serializer)
 
-    def _create_consumer(self, bindings):
-        return Consumer(self.channel, bindings, no_ack=True)
+    def _create_consumer(self, bindings, channel):
+        return Consumer(channel, bindings, no_ack=True)
 
     def store_result(self, task_id, result, status, traceback=None,
             max_retries=20, retry_delay=0.2):
@@ -99,17 +100,22 @@ class AMQPBackend(BaseDictBackend):
                 "traceback": traceback}
 
         for i in range(max_retries + 1):
+            conn = self.pool.acquire(block=True)
+            channel = conn.channel()
             try:
-                self._create_producer(task_id).publish(meta)
-            except Exception, exc:
-                if not max_retries:
-                    raise
-                self._channel = None
-                self._connection = None
-                warnings.warn(AMQResultWarning(
-                    "Error sending result %s: %r" % (task_id, exc)))
-                time.sleep(retry_delay)
-            break
+                try:
+                    self._create_producer(task_id, channel).publish(meta)
+                except Exception, exc:
+                    if not max_retries:
+                        raise
+                    warnings.warn(AMQResultWarning(
+                        "Error sending result %s: %r" % (task_id, exc)))
+                    time.sleep(retry_delay)
+                else:
+                    break
+            finally:
+                channel.close()
+                conn.release()
 
         return result
 
@@ -138,18 +144,24 @@ class AMQPBackend(BaseDictBackend):
             return self.wait_for(task_id, timeout, cache)
 
     def poll(self, task_id):
-        binding = self._create_binding(task_id)(self.channel)
-        result = binding.get()
-        if result:
-            binding.delete(if_unused=True, if_empty=True, nowait=True)
-            payload = self._cache[task_id] = result.payload
-            return payload
-        elif task_id in self._cache:
-            return self._cache[task_id]     # use previously received state.
-        return {"status": states.PENDING, "result": None}
+        conn = self.pool.acquire(block=True)
+        channel = conn.channel()
+        try:
+            binding = self._create_binding(task_id)(channel)
+            result = binding.get()
+            if result:
+                binding.delete(if_unused=True, if_empty=True, nowait=True)
+                payload = self._cache[task_id] = result.payload
+                return payload
+            elif task_id in self._cache:
+                return self._cache[task_id]     # use previously received state.
+            return {"status": states.PENDING, "result": None}
+        finally:
+            channel.close()
+            conn.release()
 
     def drain_events(self, consumer, timeout=None):
-        wait = self.connection.drain_events
+        wait = consumer.channel.connection.drain_events
         results = {}
 
         def callback(meta, message):
@@ -173,60 +185,63 @@ class AMQPBackend(BaseDictBackend):
         return results
 
     def consume(self, task_id, timeout=None):
-        binding = self._create_binding(task_id)
-        consumer = self._create_consumer(binding)
-        consumer.consume()
+        conn = self.pool.acquire(block=True)
+        channel = conn.channel()
         try:
-            return self.drain_events(consumer, timeout=timeout).values()[0]
+            binding = self._create_binding(task_id)
+            consumer = self._create_consumer(binding, channel)
+            consumer.consume()
+            try:
+                return self.drain_events(consumer, timeout=timeout).values()[0]
+            finally:
+                consumer.cancel()
         finally:
-            consumer.cancel()
+            channel.release()
+            conn.release()
 
     def get_many(self, task_ids, timeout=None):
-        bindings = [self._create_binding(task_id) for task_id in task_ids]
-        consumer = self._create_consumer(bindings)
-        consumer.consume()
-        ids = set(task_ids)
-        results = {}
-        cached_ids = set()
-        for task_id in ids:
-            try:
-                cached = self._cache[task_id]
-            except KeyError:
-                pass
-            else:
-                if cached["status"] in states.READY_STATES:
-                    results[task_id] = cached
-                    cached_ids.add(task_id)
-        ids ^= cached_ids
+        conn = self.pool.acquire(block=True)
+        channel = conn.channel()
         try:
-            while ids:
-                r = self.drain_events(consumer, timeout=timeout)
-                results.update(r)
-                ids ^= set(r.keys())
+            bindings = [self._create_binding(task_id) for task_id in task_ids]
+            consumer = self._create_consumer(bindings, channel)
+            consumer.consume()
+            ids = set(task_ids)
+            results = {}
+            cached_ids = set()
+            for task_id in ids:
+                try:
+                    cached = self._cache[task_id]
+                except KeyError:
+                    pass
+                else:
+                    if cached["status"] in states.READY_STATES:
+                        results[task_id] = cached
+                        cached_ids.add(task_id)
+            ids ^= cached_ids
+            try:
+                while ids:
+                    r = self.drain_events(consumer, timeout=timeout)
+                    results.update(r)
+                    ids ^= set(r.keys())
+            finally:
+                consumer.cancel()
         finally:
-            consumer.cancel()
+            channel.close()
+            conn.release()
 
-        return results
+    return results
 
     def close(self):
-        if self._channel is not None:
-            self._channel.close()
-            self._channel = None
-        if self._connection is not None:
-            self._connection.close()
-            self._connection = None
-
-    @property
-    def connection(self):
-        if not self._connection:
-            self._connection = self.app.broker_connection()
-        return self._connection
+        if self._pool is not None:
+            self._pool.close()
+            self._pool = None
 
     @property
-    def channel(self):
-        if not self._channel:
-            self._channel = self.connection.channel()
-        return self._channel
+    def pool(self):
+        if not self._pool:
+            self._pool = self.app.broker_connection().Pool(self.connection_max)
+        return self._pool
 
     def reload_task_result(self, task_id):
         raise NotImplementedError(