Ver código fonte

Make with_connection a decorator

Ask Solem 15 anos atrás
pai
commit
3fa0d1bb12
3 arquivos alterados com 37 adições e 24 exclusões
  1. 3 3
      celery/execute.py
  2. 21 1
      celery/messaging.py
  3. 13 20
      celery/task/control.py

+ 3 - 3
celery/execute.py

@@ -10,7 +10,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.utils import gen_unique_id, noop, fun_takes_kwargs
 from celery.result import AsyncResult, EagerResult
 from celery.registry import tasks
-from celery.messaging import TaskPublisher, with_connection
+from celery.messaging import TaskPublisher, with_connection_inline
 from celery.exceptions import RetryTaskError
 from celery.datastructures import ExceptionInfo
 
@@ -94,8 +94,8 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
         finally:
             publisher or publish.close()
 
-    task_id = with_connection(_delay_task, connection=connection,
-                                           connect_timeout=connect_timeout)
+    task_id = with_connection_inline(_delay_task, connection=connection,
+                                     connect_timeout=connect_timeout)
     return AsyncResult(task_id)
 
 

+ 21 - 1
celery/messaging.py

@@ -3,8 +3,10 @@
 Sending and Receiving Messages
 
 """
+
 from carrot.connection import DjangoBrokerConnection, AMQPConnectionException
 from carrot.messaging import Publisher, Consumer, ConsumerSet
+from billiard.utils.functional import wraps
 
 from celery import conf
 from celery import signals
@@ -109,7 +111,25 @@ def establish_connection(connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
     return DjangoBrokerConnection(connect_timeout=connect_timeout)
 
 
-def with_connection(fun, connection=None,
+def with_connection(fun):
+
+    @wraps(fun)
+    def _inner(*args, **kwargs):
+        connection = kwargs.get("connection")
+        timeout = kwargs.get("connect_timeout",
+                                conf.AMQP_CONNECTION_TIMEOUT)
+        kwargs["connection"] = conn = connection or \
+                establish_connection(connect_timeout=timeout)
+        close_connection = not connection and conn.close or noop
+
+        try:
+            return fun(*args, **kwargs)
+        finally:
+            close_connection()
+    return _inner
+
+
+def with_connection_inline(fun, connection=None,
         connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
     conn = connection or establish_connection()
     close_connection = not connection and conn.close or noop

+ 13 - 20
celery/task/control.py

@@ -2,7 +2,8 @@ from celery import conf
 from celery.messaging import TaskConsumer, BroadcastPublisher, with_connection
 
 
-def discard_all(connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+@with_connection
+def discard_all(connection=None, connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
     """Discard all waiting tasks.
 
     This will ignore all tasks waiting for execution, and they will
@@ -11,17 +12,14 @@ def discard_all(connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
     :returns: the number of tasks discarded.
 
     """
-
-    def _discard(connection):
-        consumer = TaskConsumer(connection=connection)
-        try:
-            return consumer.discard_all()
-        finally:
-            consumer.close()
-
-    return with_connection(_discard, connect_timeout=connect_timeout)
+    consumer = TaskConsumer(connection=connection)
+    try:
+        return consumer.discard_all()
+    finally:
+        consumer.close()
 
 
+@with_connection
 def revoke(task_id, connection=None,
         connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
     """Revoke a task by id.
@@ -30,13 +28,8 @@ def revoke(task_id, connection=None,
     it after all.
 
     """
-
-    def _revoke(connection):
-        broadcast = BroadcastPublisher(connection)
-        try:
-            broadcast.revoke(task_id)
-        finally:
-            broadcast.close()
-
-    return with_connection(_revoke, connection=connection,
-                           connect_timeout=connect_timeout)
+    broadcast = BroadcastPublisher(connection)
+    try:
+        broadcast.revoke(task_id)
+    finally:
+        broadcast.close()