Browse Source

Added timeout support for amqp backend's .wait()

>>> AsyncResult("nonexistingid").get(timeout=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/devel/celery/celery/result.py", line 53, in get
    return self.wait(timeout=timeout)
  File "/opt/devel/celery/celery/result.py", line 68, in wait
    return self.backend.wait_for(self.task_id, timeout=timeout)
  File "/opt/devel/celery/celery/backends/amqp.py", line 88, in wait_for
    raise TimeoutError("The operation timed out.")
celery.exceptions.TimeoutError: The operation timed out.
Ask Solem 15 years ago
parent
commit
b56535a70a
3 changed files with 22 additions and 6 deletions
  1. 20 4
      celery/backends/amqp.py
  2. 1 1
      celery/tests/test_worker.py
  3. 1 1
      celery/worker/listener.py

+ 20 - 4
celery/backends/amqp.py

@@ -1,7 +1,11 @@
 """celery.backends.amqp"""
+import socket
+
 from carrot.messaging import Consumer, Publisher
 
 from celery import conf
+from celery import states
+from celery.exceptions import TimeoutError
 from celery.backends.base import BaseDictBackend
 from celery.messaging import establish_connection
 
@@ -47,10 +51,9 @@ class AMQPBackend(BaseDictBackend):
     def _publisher_for_task_id(self, task_id, connection):
         routing_key = task_id.replace("-", "")
         self._declare_queue(task_id, connection)
-        p = Publisher(connection, exchange=self.exchange,
+        return Publisher(connection, exchange=self.exchange,
                       exchange_type="direct",
                       routing_key=routing_key)
-        return p
 
     def _consumer_for_task_id(self, task_id, connection):
         routing_key = task_id.replace("-", "")
@@ -78,7 +81,18 @@ class AMQPBackend(BaseDictBackend):
 
         return result
 
-    def _get_task_meta_for(self, task_id):
+    def wait_for(self, task_id, timeout=None):
+        try:
+            meta = self._get_task_meta_for(task_id, timeout)
+        except socket.timeout:
+            raise TimeoutError("The operation timed out.")
+
+        if meta["status"] == states.SUCCESS:
+            return self.get_result(task_id)
+        elif meta["status"] == states.FAILURE:
+            raise self.get_result(task_id)
+
+    def _get_task_meta_for(self, task_id, timeout=None):
         assert task_id not in self._seen
         self._use_debug_tracking and self._seen.add(task_id)
 
@@ -91,11 +105,13 @@ class AMQPBackend(BaseDictBackend):
         routing_key = task_id.replace("-", "")
 
         connection = self.connection
+        wait = connection.connection.wait_multi
         consumer = self._consumer_for_task_id(task_id, connection)
         consumer.register_callback(callback)
 
+        consumer.consume()
         try:
-            consumer.iterconsume().next()
+            wait([consumer.backend.channel], timeout=timeout)
         finally:
             consumer.backend.channel.queue_delete(routing_key)
             consumer.close()

+ 1 - 1
celery/tests/test_worker.py

@@ -139,7 +139,7 @@ class TestCarrotListener(unittest.TestCase):
             def drain_events(self):
                 return "draining"
 
-        l.connection = PlaceHolder()
+        l.connection = MockConnection()
         l.connection.connection = MockConnection()
 
         it = l._mainloop()

+ 1 - 1
celery/worker/listener.py

@@ -227,7 +227,7 @@ class CarrotListener(object):
 
     def _mainloop(self, **kwargs):
         while 1:
-            yield self.connection.connection.drain_events()
+            yield self.connection.drain_events()
 
     def _detect_wait_method(self):
         if hasattr(self.connection.connection, "drain_events"):