Explorar o código

AMQP backend: Retry sending results

Ask Solem %!s(int64=14) %!d(string=hai) anos
pai
achega
eec3714374
Modificáronse 1 ficheiros con 20 adicións e 7 borrados
  1. 20 7
      celery/backends/amqp.py

+ 20 - 7
celery/backends/amqp.py

@@ -1,6 +1,7 @@
 """celery.backends.amqp"""
 import socket
 import time
+import warnings
 
 from datetime import timedelta
 
@@ -14,6 +15,10 @@ from celery.messaging import establish_connection
 from celery.utils import timeutils
 
 
+class AMQResultWarning(UserWarning):
+    pass
+
+
 class ResultPublisher(Publisher):
     exchange = conf.RESULT_EXCHANGE
     exchange_type = conf.RESULT_EXCHANGE_TYPE
@@ -95,7 +100,8 @@ class AMQPBackend(BaseDictBackend):
                               auto_delete=self.auto_delete,
                               expires=self.expires)
 
-    def store_result(self, task_id, result, status, traceback=None):
+    def store_result(self, task_id, result, status, traceback=None,
+            max_retries=20, retry_delay=0.2):
         """Send task return value and status."""
         result = self.encode_result(result, status)
 
@@ -104,11 +110,19 @@ class AMQPBackend(BaseDictBackend):
                 "status": status,
                 "traceback": traceback}
 
-        publisher = self._create_publisher(task_id, self.connection)
-        try:
-            publisher.send(meta)
-        finally:
-            publisher.close()
+        for i in range(max_retries + 1):
+            try:
+                publisher = self._create_publisher(task_id, self.connection)
+                publisher.send(meta)
+                publisher.close()
+            except Exception, exc:
+                if not max_retries:
+                    raise
+                self._connection = None
+                warnings.warn(AMQResultWarning(
+                    "Error sending result %s: %r" % (task_id, exc)))
+                time.sleep(retry_delay)
+            break
 
         return result
 
@@ -133,7 +147,6 @@ class AMQPBackend(BaseDictBackend):
         else:
             return self.wait_for(task_id, timeout, cache)
 
-
     def poll(self, task_id):
         consumer = self._create_consumer(task_id, self.connection)
         result = consumer.fetch()