Cosmetics

Ask Solem, 14 years ago
commit bea1ab05e6
2 changed files with 19 additions and 41 deletions
  1. celery/backends/amqp.py  +18 -39
  2. celery/backends/base.py  +1 -2

+ 18 - 39
celery/backends/amqp.py

@@ -22,14 +22,7 @@ def repair_uuid(s):
 
 
 class AMQPBackend(BaseDictBackend):
-    """AMQP backend. Publish results by sending messages to the broker
-    using the task id as routing key.
-
-    **NOTE:** Results published using this backend is read-once only.
-    After the result has been read, the result is deleted. (however, it's
-    still cached locally by the backend instance).
-
-    """
+    """Publishes results by sending messages."""
     Exchange = Exchange
     Queue = Queue
     Consumer = Consumer
@@ -44,12 +37,12 @@ class AMQPBackend(BaseDictBackend):
         conf = self.app.conf
         self._connection = connection
         self.queue_arguments = {}
-        exchange = exchange or conf.CELERY_RESULT_EXCHANGE
-        exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
         if persistent is None:
             persistent = conf.CELERY_RESULT_PERSISTENT
         self.persistent = persistent
         delivery_mode = persistent and "persistent" or "transient"
+        exchange = exchange or conf.CELERY_RESULT_EXCHANGE
+        exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
         self.exchange = self.Exchange(name=exchange,
                                       type=exchange_type,
                                       delivery_mode=delivery_mode,
@@ -64,9 +57,7 @@ class AMQPBackend(BaseDictBackend):
             self.expires = timeutils.timedelta_seconds(self.expires)
         if self.expires is not None:
             self.expires = int(self.expires)
-            # WARNING: Requires RabbitMQ 2.1.0 or higher.
-            # x-expires must be a signed-int, or long describing
-            # the expiry time in milliseconds.
+            # requires RabbitMQ 2.1.0 or higher.
             self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
         self.connection_max = (connection_max or
                                conf.CELERY_AMQP_TASK_RESULT_CONNECTION_MAX)
@@ -81,9 +72,7 @@ class AMQPBackend(BaseDictBackend):
                           queue_arguments=self.queue_arguments)
 
     def _create_producer(self, task_id, channel):
-        binding = self._create_binding(task_id)
-        binding(channel).declare()
-
+        self._create_binding(task_id)(channel).declare()
         return self.Producer(channel, exchange=self.exchange,
                              routing_key=task_id.replace("-", ""),
                              serializer=self.serializer)
@@ -92,6 +81,7 @@ class AMQPBackend(BaseDictBackend):
         return self.Consumer(channel, bindings, no_ack=True)
 
     def _publish_result(self, connection, task_id, meta):
+        # cache single channel
         if hasattr(connection, "_result_producer_chan") and \
                 connection._result_producer_chan is not None and \
                 connection._result_producer_chan.connection is not None:
@@ -111,20 +101,16 @@ class AMQPBackend(BaseDictBackend):
             max_retries=20, interval_start=0, interval_step=1,
             interval_max=1):
         """Send task return value and status."""
-        result = self.encode_result(result, status)
-
-        meta = {"task_id": task_id,
-                "result": result,
-                "status": status,
-                "traceback": traceback}
-
         conn = self.pool.acquire(block=True)
         try:
-            conn.ensure(self, self._publish_result,
+            send = conn.ensure(self, self._publish_result,
                         max_retries=max_retries,
                         interval_start=interval_start,
                         interval_step=interval_step,
-                        interval_max=interval_max)(conn, task_id, meta)
+                        interval_max=interval_max)
+            send(conn, task_id, {"task_id": task_id, "status": status,
+                                 "result": self.encode_result(result, status),
+                                 "traceback": traceback})
         finally:
             conn.release()
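
Note: the hunk above only gives the `ensure()`-wrapped publisher a name, `send`, and builds the meta dict inline; behaviour is unchanged. For readers unfamiliar with the pattern, a minimal stand-alone sketch of a kombu-style `connection.ensure()` retry wrapper follows. The broker URL, exchange name and payload are assumptions for illustration, not part of this commit.

    from kombu import Connection, Exchange, Producer

    results = Exchange("celeryresults", type="direct")              # exchange name assumed
    with Connection("amqp://guest:guest@localhost//") as conn:      # broker URL assumed
        producer = Producer(conn.channel(), exchange=results)
        # ensure() returns a wrapper that retries the call on recoverable
        # connection errors, reconnecting between attempts, using the same
        # retry knobs as store_result() above.
        safe_publish = conn.ensure(producer, producer.publish,
                                   max_retries=20, interval_start=0,
                                   interval_step=1, interval_max=1)
        safe_publish({"task_id": "1234", "status": "SUCCESS",
                      "result": 42, "traceback": None},
                     routing_key="1234", declare=[results], serializer="json")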
 
@@ -133,13 +119,11 @@ class AMQPBackend(BaseDictBackend):
     def get_task_meta(self, task_id, cache=True):
         if cache and task_id in self._cache:
             return self._cache[task_id]
-
         return self.poll(task_id)
 
     def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
             **kwargs):
         cached_meta = self._cache.get(task_id)
-
         if cache and cached_meta and \
                 cached_meta["status"] in states.READY_STATES:
             meta = cached_meta
@@ -169,15 +153,14 @@ class AMQPBackend(BaseDictBackend):
             if result:
                 payload = self._cache[task_id] = result.payload
                 return payload
-            elif task_id in self._cache:
-                # use previously received state.
+            elif task_id in self._cache:  # use previously received state.
                 return self._cache[task_id]
             return {"status": states.PENDING, "result": None}
         finally:
             channel.close()
             conn.release()
 
-    def drain_events(self, connection, consumer, timeout=None):
+    def drain_events(self, connection, consumer, timeout=None, now=time.time):
         wait = connection.drain_events
         results = {}
 
@@ -185,19 +168,16 @@ class AMQPBackend(BaseDictBackend):
             if meta["status"] in states.READY_STATES:
                 uuid = repair_uuid(message.delivery_info["routing_key"])
                 results[uuid] = meta
-
         consumer.register_callback(callback)
 
-        time_start = time.time()
+        time_start = now()
         while 1:
             # Total time spent may exceed a single call to wait()
-            if timeout and time.time() - time_start >= timeout:
+            if timeout and now() - time_start >= timeout:
                 raise socket.timeout()
             wait(timeout=timeout)
-            if results:
-                # Got event on the wanted channel.
+            if results:  # got event on the wanted channel.
                 break
-
         self._cache.update(results)
         return results
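
Note: the new `now=time.time` keyword argument lets the timeout bookkeeping be driven by a fake clock in tests instead of patching `time.time` globally. A small, assumed example of the same pattern in isolation:

    import socket
    import time

    def wait_for_event(poll, timeout=None, now=time.time):
        # Same shape as drain_events(): total elapsed time is measured with
        # now(), so a test can pass any callable returning "seconds".
        started = now()
        while 1:
            if timeout and now() - started >= timeout:
                raise socket.timeout()
            if poll():
                break

    # A fake clock advancing half a second per call exercises the timeout path:
    ticks = iter(i * 0.5 for i in range(100))
    try:
        wait_for_event(lambda: False, timeout=1.0, now=lambda: next(ticks))
    except socket.timeout:
        print("timed out, as expected")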
 
@@ -209,8 +189,7 @@ class AMQPBackend(BaseDictBackend):
             consumer = self._create_consumer(binding, channel)
             consumer.consume()
             try:
-                return self.drain_events(conn, consumer,
-                                         timeout=timeout).values()[0]
+                return self.drain_events(conn, consumer, timeout).values()[0]
             finally:
                 consumer.cancel()
         finally:
@@ -239,7 +218,7 @@ class AMQPBackend(BaseDictBackend):
             consumer.consume()
             try:
                 while ids:
-                    r = self.drain_events(conn, consumer, timeout=timeout)
+                    r = self.drain_events(conn, consumer, timeout)
                     ids ^= set(r.keys())
                     for ready_id, ready_meta in r.items():
                         yield ready_id, ready_meta
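
Note: for context, each result is published to its own queue, bound with the dash-stripped task id as routing key, and the `x-expires` argument set earlier (in milliseconds) lets RabbitMQ 2.1.0 or later delete result queues that go unused. A rough kombu-style equivalent of what `_create_binding()` declares; the exchange name, broker URL and expiry value here are assumptions:

    from kombu import Connection, Exchange, Queue

    task_id = "d9e9f0a1-1b2c-4d5e-8f90-123456789abc"
    routing_key = task_id.replace("-", "")     # undone by repair_uuid() when reading

    results = Exchange("celeryresults", type="direct", delivery_mode="transient")
    binding = Queue(name=routing_key, exchange=results, routing_key=routing_key,
                    # milliseconds; honoured by RabbitMQ 2.1.0 or later
                    queue_arguments={"x-expires": int(86400 * 1000.0)})

    with Connection("amqp://guest:guest@localhost//") as conn:
        binding(conn.channel()).declare()      # declares exchange, queue and binding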

+ 1 - 2
celery/backends/base.py

@@ -9,8 +9,7 @@ from celery.datastructures import LocalCache
 
 
 class BaseBackend(object):
-    """The base backend class. All backends should inherit from this."""
-
+    """Base backend class."""
     READY_STATES = states.READY_STATES
     UNREADY_STATES = states.UNREADY_STATES
     EXCEPTION_STATES = states.EXCEPTION_STATES