|
@@ -2,7 +2,6 @@ import sys
|
|
|
from datetime import datetime, timedelta
|
|
|
from Queue import Queue
|
|
|
|
|
|
-from carrot.connection import DjangoBrokerConnection
|
|
|
from billiard.serialization import pickle
|
|
|
|
|
|
from celery import conf
|
|
@@ -13,6 +12,7 @@ from celery.execute import apply_async, apply
|
|
|
from celery.registry import tasks
|
|
|
from celery.backends import default_backend
|
|
|
from celery.messaging import TaskPublisher, TaskConsumer
|
|
|
+from celery.messaging import establish_connection as _establish_connection
|
|
|
from celery.exceptions import MaxRetriesExceededError, RetryTaskError
|
|
|
|
|
|
|
|
@@ -192,7 +192,13 @@ class Task(object):
|
|
|
loglevel = kwargs.get("loglevel")
|
|
|
return setup_logger(loglevel=loglevel, logfile=logfile)
|
|
|
|
|
|
- def get_publisher(self, connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
|
|
|
+ def establish_connection(self,
|
|
|
+ connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
|
|
|
+ """Establish a connection to the message broker."""
|
|
|
+ return _establish_connection(connect_timeout)
|
|
|
+
|
|
|
+ def get_publisher(self, connection=None,
|
|
|
+ connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
|
|
|
"""Get a celery task message publisher.
|
|
|
|
|
|
:rtype: :class:`celery.messaging.TaskPublisher`.
|
|
@@ -205,13 +211,13 @@ class Task(object):
|
|
|
>>> publisher.connection.close()
|
|
|
|
|
|
"""
|
|
|
-
|
|
|
- connection = DjangoBrokerConnection(connect_timeout=connect_timeout)
|
|
|
+ connection = connection or self.establish_connection(connect_timeout)
|
|
|
return TaskPublisher(connection=connection,
|
|
|
exchange=self.exchange,
|
|
|
routing_key=self.routing_key)
|
|
|
|
|
|
- def get_consumer(self, connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
|
|
|
+ def get_consumer(self, connection=None,
|
|
|
+ connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
|
|
|
"""Get a celery task message consumer.
|
|
|
|
|
|
:rtype: :class:`celery.messaging.TaskConsumer`.
|
|
@@ -224,7 +230,7 @@ class Task(object):
|
|
|
>>> consumer.connection.close()
|
|
|
|
|
|
"""
|
|
|
- connection = DjangoBrokerConnection(connect_timeout=connect_timeout)
|
|
|
+ connection = connection or self.establish_connection(connect_timeout)
|
|
|
return TaskConsumer(connection=connection, exchange=self.exchange,
|
|
|
routing_key=self.routing_key)
|
|
|
|
|
@@ -463,6 +469,9 @@ class TaskSet(object):
|
|
|
task_name = task
|
|
|
task_obj = tasks[task_name]
|
|
|
|
|
|
+ # Get task instance
|
|
|
+ task_obj = tasks[task_obj.name]
|
|
|
+
|
|
|
self.task = task_obj
|
|
|
self.task_name = task_name
|
|
|
self.arguments = args
|
|
@@ -506,14 +515,15 @@ class TaskSet(object):
|
|
|
for args, kwargs in self.arguments]
|
|
|
return TaskSetResult(taskset_id, subtasks)
|
|
|
|
|
|
- conn = DjangoBrokerConnection(connect_timeout=connect_timeout)
|
|
|
- publisher = TaskPublisher(connection=conn,
|
|
|
- exchange=self.task.exchange)
|
|
|
- subtasks = [apply_async(self.task, args, kwargs,
|
|
|
- taskset_id=taskset_id, publisher=publisher)
|
|
|
- for args, kwargs in self.arguments]
|
|
|
- publisher.close()
|
|
|
- conn.close()
|
|
|
+ conn = self.task.establish_connection(connect_timeout=connect_timeout)
|
|
|
+ publisher = self.task.get_publisher(connection=conn)
|
|
|
+ try:
|
|
|
+ subtasks = [apply_async(self.task, args, kwargs,
|
|
|
+ taskset_id=taskset_id, publisher=publisher)
|
|
|
+ for args, kwargs in self.arguments]
|
|
|
+ finally:
|
|
|
+ publisher.close()
|
|
|
+ conn.close()
|
|
|
return TaskSetResult(taskset_id, subtasks)
|
|
|
|
|
|
@classmethod
|