celery.backends.amqp: Finally with with-statements

Ask Solem, 14 years ago
commit 7a953e3431
1 changed file with 18 additions and 53 deletions

celery/backends/amqp.py  (+18, −53)

@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+from __future__ import absolute_import, with_statement
+
 import socket
 import threading
 import time
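The new __future__ line is what makes the rest of the patch work on older interpreters: with_statement is required on Python 2.5 (the statement is only on by default from 2.6), and absolute_import keeps stdlib modules like socket and time from being shadowed by local modules of the same name. A minimal sketch of what the with-statement replaces, using only the stdlib:

    from __future__ import with_statement  # no-op on 2.6+, required on 2.5

    import threading

    lock = threading.Lock()

    # `with lock:` compiles down to roughly the acquire/try/finally pattern
    # that this commit deletes by hand further below:
    with lock:
        pass  # critical section; lock.release() runs even if this block raises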
@@ -42,9 +44,8 @@ class AMQPBackend(BaseDictBackend):
         conf = self.app.conf
         self._connection = connection
         self.queue_arguments = {}
-        if persistent is None:
-            persistent = conf.CELERY_RESULT_PERSISTENT
-        self.persistent = persistent
+        self.persistent = (conf.CELERY_RESULT_PERSISTENT if persistent is None
+                                                         else persistent)
         delivery_mode = persistent and "persistent" or "transient"
         exchange = exchange or conf.CELERY_RESULT_EXCHANGE
         exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
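Two styles of conditional assignment now sit side by side in __init__: the new `x if cond else y` expression for self.persistent, and the older `cond and a or b` idiom for delivery_mode. The idiom only works because "persistent" is a truthy string; a standalone illustration (not part of the patch):

    persistent = False
    delivery_mode = persistent and "persistent" or "transient"
    assert delivery_mode == "transient"

    # Equivalent Python 2.5+ spelling, matching the style now used for
    # self.persistent:
    delivery_mode = "persistent" if persistent else "transient"
    assert delivery_mode == "transient"

    # The and/or idiom silently misbehaves when the "true" branch is falsy,
    # which is why the conditional expression is generally preferred:
    value = True and "" or "fallback"
    assert value == "fallback"  # wanted "", got the fallback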
@@ -55,15 +56,14 @@ class AMQPBackend(BaseDictBackend):
                                       auto_delete=auto_delete)
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.auto_delete = auto_delete
-        self.expires = expires
-        if self.expires is None:
-            self.expires = conf.CELERY_AMQP_TASK_RESULT_EXPIRES
+        self.expires = (conf.CELERY_AMQP_TASK_RESULT_EXPIRES if expires is None
+                                                             else expires)
         if isinstance(self.expires, timedelta):
             self.expires = timeutils.timedelta_seconds(self.expires)
         if self.expires is not None:
             self.expires = int(self.expires)
-            # requires RabbitMQ 2.1.0 or higher.
-            self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
+            # x-expires requires RabbitMQ 2.1.0 or higher.
+            self.queue_arguments["x-expires"] = self.expires * 1000.0
         self.connection_max = (connection_max or
                                conf.CELERY_AMQP_TASK_RESULT_CONNECTION_MAX)
         self.mutex = threading.Lock()
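The expires handling folds the default lookup into one expression and converts the setting into the milliseconds RabbitMQ expects for the x-expires queue argument (available from RabbitMQ 2.1.0). A self-contained sketch of the conversion, assuming a timedelta value; the manual seconds computation stands in for what timeutils.timedelta_seconds() roughly does for positive deltas:

    from datetime import timedelta

    expires = timedelta(days=1)  # e.g. CELERY_AMQP_TASK_RESULT_EXPIRES
    seconds = expires.days * 86400 + expires.seconds
    queue_arguments = {"x-expires": int(seconds * 1000.0)}  # RabbitMQ wants ms
    assert queue_arguments["x-expires"] == 86400000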
@@ -104,10 +104,8 @@ class AMQPBackend(BaseDictBackend):
             max_retries=20, interval_start=0, interval_step=1,
             interval_max=1):
         """Send task return value and status."""
-        self.mutex.acquire()
-        try:
-            conn = self.app.pool.acquire(block=True)
-            try:
+        with self.mutex:
+            with self.app.pool.acquire(block=True) as conn:
 
                 def errback(error, delay):
                     conn._result_producer_chan = None
@@ -123,11 +121,6 @@ class AMQPBackend(BaseDictBackend):
                 send(conn, task_id, {"task_id": task_id, "status": status,
                                 "result": self.encode_result(result, status),
                                 "traceback": traceback})
-            finally:
-                conn.release()
-        finally:
-            self.mutex.release()
-
         return result
 
     def get_task_meta(self, task_id, cache=True):
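store_result() now leans on the pooled connection being usable as a context manager, so both the mutex and the connection are released on every exit path without the nested try/finally blocks. A minimal, self-contained sketch of a pool whose members release themselves on __exit__ (hypothetical names, not the kombu implementation):

    import threading

    class PooledConnection(object):
        def __init__(self, pool):
            self.pool = pool

        def release(self):
            self.pool.release(self)

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            self.release()  # runs even if the with-block raised

    class ConnectionPool(object):
        def __init__(self):
            self._lock = threading.Lock()
            self._free = []

        def acquire(self, block=True):
            with self._lock:
                return self._free.pop() if self._free else PooledConnection(self)

        def release(self, connection):
            with self._lock:
                self._free.append(connection)

    pool = ConnectionPool()
    mutex = threading.Lock()
    with mutex:                                  # replaces mutex.acquire()/release()
        with pool.acquire(block=True) as conn:   # replaces the inner try/finally
            pass                                 # publish the result here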
@@ -156,9 +149,7 @@ class AMQPBackend(BaseDictBackend):
             return self.wait_for(task_id, timeout, cache)
 
     def poll(self, task_id, backlog_limit=100):
-        conn = self.app.pool.acquire(block=True)
-        channel = conn.channel()
-        try:
+        with self.app.pool.acquire_channel(block=True) as (_, channel):
             binding = self._create_binding(task_id)(channel)
             binding.declare()
             latest, acc = None, None
@@ -174,9 +165,6 @@ class AMQPBackend(BaseDictBackend):
             elif task_id in self._cache:  # use previously received state.
                 return self._cache[task_id]
             return {"status": states.PENDING, "result": None}
-        finally:
-            channel.close()
-            conn.release()
 
     def drain_events(self, connection, consumer, timeout=None, now=time.time):
         wait = connection.drain_events
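poll(), consume() and get_many() now go through an acquire_channel() helper that hands back both the pooled connection and an open channel, and closes the channel and releases the connection on exit. A rough sketch of such a helper written with contextlib (the real method lives on the app's connection pool; this is only the assumed shape):

    from contextlib import contextmanager

    @contextmanager
    def acquire_channel(pool, block=True):
        conn = pool.acquire(block=block)
        try:
            channel = conn.channel()
            try:
                # yielding a tuple is what allows `as (conn, channel)` above
                yield conn, channel
            finally:
                channel.close()
        finally:
            conn.release()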
@@ -186,9 +174,10 @@ class AMQPBackend(BaseDictBackend):
             if meta["status"] in states.READY_STATES:
                 uuid = repair_uuid(message.delivery_info["routing_key"])
                 results[uuid] = meta
-        consumer.register_callback(callback)
 
+        consumer.callbacks[:] = [callback]
         time_start = now()
+
         while 1:
             # Total time spent may exceed a single call to wait()
             if timeout and now() - time_start >= timeout:
@@ -200,24 +189,13 @@ class AMQPBackend(BaseDictBackend):
         return results
 
     def consume(self, task_id, timeout=None):
-        conn = self.app.pool.acquire(block=True)
-        channel = conn.channel()
-        try:
+        with self.app.pool.acquire_channel(block=True) as (conn, channel):
             binding = self._create_binding(task_id)
-            consumer = self._create_consumer(binding, channel)
-            consumer.consume()
-            try:
+            with self._create_consumer(binding, channel) as consumer:
                 return self.drain_events(conn, consumer, timeout).values()[0]
-            finally:
-                consumer.cancel()
-        finally:
-            channel.close()
-            conn.release()
 
     def get_many(self, task_ids, timeout=None):
-        conn = self.app.pool.acquire(block=True)
-        channel = conn.channel()
-        try:
+        with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
             cached_ids = set()
             for task_id in ids:
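consume() and get_many() also treat the consumer itself as a context manager, replacing the explicit consume()/cancel() pairs. A hypothetical stand-in showing the behaviour being relied on (start consuming on enter, cancel on exit); the real consumer class comes from kombu:

    class ConsumerGuard(object):
        def __init__(self, consumer):
            self.consumer = consumer

        def __enter__(self):
            self.consumer.consume()
            return self.consumer

        def __exit__(self, *exc_info):
            self.consumer.cancel()  # runs whether draining succeeded or raised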
@@ -229,28 +207,15 @@ class AMQPBackend(BaseDictBackend):
                     if cached["status"] in states.READY_STATES:
                         yield task_id, cached
                         cached_ids.add(task_id)
-            ids ^= cached_ids
 
+            ids ^= cached_ids
             bindings = [self._create_binding(task_id) for task_id in task_ids]
-            consumer = self._create_consumer(bindings, channel)
-            consumer.consume()
-            try:
+            with self._create_consumer(bindings, channel) as consumer:
                 while ids:
                     r = self.drain_events(conn, consumer, timeout)
                     ids ^= set(r.keys())
                     for ready_id, ready_meta in r.items():
                         yield ready_id, ready_meta
-            except:   # ☹ Py2.4 — Cannot yield inside try: finally: block
-                consumer.cancel()
-                raise
-            consumer.cancel()
-
-        except:  # … ☹
-            channel.close()
-            conn.release()
-            raise
-        channel.close()
-        conn.release()
 
     def reload_task_result(self, task_id):
         raise NotImplementedError(
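The deleted except blocks in get_many() were a workaround for Python 2.4, where a generator could not yield inside a try/finally, so cleanup had to be duplicated on the success and error paths. From Python 2.5 (PEP 342) that restriction is gone, which is what allows the with-blocks to wrap the yields directly. A standalone sketch of the pattern:

    from contextlib import contextmanager

    @contextmanager
    def consumer():
        print("consume")
        try:
            yield ["meta-1", "meta-2"]
        finally:
            print("cancel")          # cleanup runs however the generator ends

    def get_results():
        with consumer() as results:  # yield inside `with` needs Python 2.5+
            for meta in results:
                yield meta

    gen = get_results()
    print(next(gen))                 # "consume", then "meta-1"
    gen.close()                      # GeneratorExit at the yield; "cancel" is printed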