Browse Source

Get DjangoBrokerConnection from one place: celery.messaging.establish_connection

Ask Solem 15 years ago
parent
commit
5772b2bc98

+ 2 - 2
celery/backends/amqp.py

@@ -1,7 +1,7 @@
 """celery.backends.amqp"""
 from carrot.messaging import Consumer, Publisher
-from carrot.connection import DjangoBrokerConnection
 
+from celery.messaging import establish_connection
 from celery.backends.base import BaseBackend
 
 RESULTSTORE_EXCHANGE = "celeryresults"
@@ -21,7 +21,7 @@ class AMQPBackend(BaseBackend):
 
     def __init__(self, *args, **kwargs):
         super(AMQPBackend, self).__init__(*args, **kwargs)
-        self.connection = DjangoBrokerConnection()
+        self.connection = establish_connection()
         self._cache = {}
 
     def _declare_queue(self, task_id, connection):

+ 0 - 2
celery/bin/celeryd.py

@@ -67,8 +67,6 @@ import multiprocessing
 import traceback
 import optparse
 
-from carrot.connection import DjangoBrokerConnection
-
 from celery import conf
 from celery import platform
 from celery import __version__

+ 0 - 1
celery/execute.py

@@ -3,7 +3,6 @@ import inspect
 import traceback
 from datetime import datetime, timedelta
 
-from carrot.connection import DjangoBrokerConnection
 from billiard.utils.functional import curry
 
 from celery import signals

+ 1 - 1
celery/messaging.py

@@ -3,7 +3,7 @@
 Sending and Receiving Messages
 
 """
-from carrot.connection import DjangoBrokerConnection
+from carrot.connection import DjangoBrokerConnection, AMQPConnectionException
 from carrot.messaging import Publisher, Consumer, ConsumerSet
 
 from celery import conf

+ 9 - 6
celery/task/__init__.py

@@ -3,7 +3,6 @@
 Working with tasks and task sets.
 
 """
-from carrot.connection import DjangoBrokerConnection
 from billiard.serialization import pickle
 
 from celery.conf import AMQP_CONNECTION_TIMEOUT
@@ -28,11 +27,15 @@ def discard_all(connect_timeout=AMQP_CONNECTION_TIMEOUT):
     :rtype: int
 
     """
-    amqp_connection = DjangoBrokerConnection(connect_timeout=connect_timeout)
-    consumer = TaskConsumer(connection=amqp_connection)
-    discarded_count = consumer.discard_all()
-    amqp_connection.close()
-    return discarded_count
+
+    def _discard(connection):
+        consumer = TaskConsumer(connection=connection)
+        try:
+            return consumer.discard_all()
+        finally:
+            consumer.close()
+
+    return with_connection(_discard, connect_timeout=connect_timeout)
 
 
 def revoke(task_id, connection=None, connect_timeout=None):

+ 24 - 14
celery/task/base.py

@@ -2,7 +2,6 @@ import sys
 from datetime import datetime, timedelta
 from Queue import Queue
 
-from carrot.connection import DjangoBrokerConnection
 from billiard.serialization import pickle
 
 from celery import conf
@@ -13,6 +12,7 @@ from celery.execute import apply_async, apply
 from celery.registry import tasks
 from celery.backends import default_backend
 from celery.messaging import TaskPublisher, TaskConsumer
+from celery.messaging import establish_connection as _establish_connection
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 
 
@@ -192,7 +192,13 @@ class Task(object):
         loglevel = kwargs.get("loglevel")
         return setup_logger(loglevel=loglevel, logfile=logfile)
 
-    def get_publisher(self, connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+    def establish_connection(self,
+            connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+        """Establish a connection to the message broker."""
+        return _establish_connection(connect_timeout)
+
+    def get_publisher(self, connection=None,
+            connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
         """Get a celery task message publisher.
 
         :rtype: :class:`celery.messaging.TaskPublisher`.
@@ -205,13 +211,13 @@ class Task(object):
             >>> publisher.connection.close()
 
         """
-
-        connection = DjangoBrokerConnection(connect_timeout=connect_timeout)
+        connection = connection or self.establish_connection(connect_timeout)
         return TaskPublisher(connection=connection,
                              exchange=self.exchange,
                              routing_key=self.routing_key)
 
-    def get_consumer(self, connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+    def get_consumer(self, connection=None,
+            connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
         """Get a celery task message consumer.
 
         :rtype: :class:`celery.messaging.TaskConsumer`.
@@ -224,7 +230,7 @@ class Task(object):
             >>> consumer.connection.close()
 
         """
-        connection = DjangoBrokerConnection(connect_timeout=connect_timeout)
+        connection = connection or self.establish_connection(connect_timeout)
         return TaskConsumer(connection=connection, exchange=self.exchange,
                             routing_key=self.routing_key)
 
@@ -463,6 +469,9 @@ class TaskSet(object):
             task_name = task
             task_obj = tasks[task_name]
 
+        # Get task instance
+        task_obj = tasks[task_obj.name]
+
         self.task = task_obj
         self.task_name = task_name
         self.arguments = args
@@ -506,14 +515,15 @@ class TaskSet(object):
                             for args, kwargs in self.arguments]
             return TaskSetResult(taskset_id, subtasks)
 
-        conn = DjangoBrokerConnection(connect_timeout=connect_timeout)
-        publisher = TaskPublisher(connection=conn,
-                                  exchange=self.task.exchange)
-        subtasks = [apply_async(self.task, args, kwargs,
-                                taskset_id=taskset_id, publisher=publisher)
-                        for args, kwargs in self.arguments]
-        publisher.close()
-        conn.close()
+        conn = self.task.establish_connection(connect_timeout=connect_timeout)
+        publisher = self.task.get_publisher(connection=conn)
+        try:
+            subtasks = [apply_async(self.task, args, kwargs,
+                                    taskset_id=taskset_id, publisher=publisher)
+                            for args, kwargs in self.arguments]
+        finally:
+            publisher.close()
+            conn.close()
         return TaskSetResult(taskset_id, subtasks)
 
     @classmethod

+ 1 - 3
celery/task/strategy.py

@@ -1,5 +1,3 @@
-from carrot.connection import DjangoBrokerConnection
-
 from celery.utils import chunks
 
 
@@ -38,7 +36,7 @@ def even_time_distribution(task, size, time_window, iterable, **apply_kwargs):
     bucketsize = size / time_window
     buckets = chunks(iterable, int(bucketsize))
 
-    connection = DjangoBrokerConnection()
+    connection = task.establish_connection()
     try:
         for bucket_count, bucket in enumerate(buckets):
             # Skew the countdown for items in this bucket by one.

+ 2 - 2
celery/worker/listener.py

@@ -2,7 +2,6 @@ import socket
 from datetime import datetime
 
 from dateutil.parser import parse as parse_iso8601
-from carrot.connection import DjangoBrokerConnection, AMQPConnectionException
 
 from celery import conf
 from celery import signals
@@ -12,6 +11,7 @@ from celery.worker.revoke import revoked
 from celery.worker.control import ControlDispatch
 from celery.worker.heartbeat import Heart
 from celery.events import EventDispatcher
+from celery.messaging import establish_connection, AMQPConnectionException
 from celery.messaging import get_consumer_set, BroadcastConsumer
 from celery.exceptions import NotRegistered
 from celery.datastructures import SharedCounter
@@ -200,7 +200,7 @@ class CarrotListener(object):
 
         def _establish_connection():
             """Establish a connection to the AMQP broker."""
-            conn = DjangoBrokerConnection()
+            conn = establish_connection()
             connected = conn.connection # Connection is established lazily.
             return conn
 

+ 6 - 7
docs/userguide/executing.rst

@@ -101,12 +101,12 @@ establish the connection yourself and pass it to ``apply_async``:
 
 .. code-block:: python
 
-    from carrot.connection import DjangoBrokerConnection
+    from celery.messaging import establish_connection
 
     numbers = [(2, 2), (4, 4), (8, 8), (16, 16)]
 
     results = []
-    connection = DjangoBrokerConnection()
+    connection = establish_connection()
     try:
         for args in numbers:
             res = AddTask.apply_async(args=args, connection=connection)
@@ -122,12 +122,12 @@ In Python 2.5 and above, you can use the ``with`` statement:
 .. code-block:: python
 
     from __future__ import with_statement
-    from carrot.connection import DjangoBrokerConnection
+    from celery.messaging import establish_connection
 
     numbers = [(2, 2), (4, 4), (8, 8), (16, 16)]
 
     results = []
-    with DjangoBrokerConnection() as connection:
+    with establish_connection() as connection:
         for args in numbers:
             res = AddTask.apply_async(args=args, connection=connection)
             results.append(res)
@@ -146,12 +146,11 @@ argument to ``apply_async``:
 
     AddTask.apply_async([10, 10], connect_timeout=3)
 
-or if you handle your connection manually by using the connection objects
-``timeout`` argument:
+or if you handle the connection manually:
 
 .. code-block:: python
 
-    connection = DjangoAMQPConnection(timeout=3)
+    connection = establish_connection(connect_timeout=3)
 
 
 Routing options