|
@@ -12,13 +12,14 @@ from celery.registry import tasks
|
|
|
from datetime import timedelta
|
|
|
from celery.backends import default_backend
|
|
|
from celery.result import AsyncResult, TaskSetResult
|
|
|
+from django.utils.functional import curry
|
|
|
import uuid
|
|
|
import pickle
|
|
|
|
|
|
|
|
|
def apply_async(task, args=None, kwargs=None, routing_key=None,
|
|
|
immediate=None, mandatory=None, connection=None,
|
|
|
- connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None):
|
|
|
+ connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None, **opts):
|
|
|
"""Run a task asynchronously by the celery daemon(s).
|
|
|
|
|
|
:param task: The task to run (a callable object, or a :class:`Task`
|
|
@@ -47,27 +48,34 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
|
|
|
:keyword priority: The task priority, a number between ``0`` and ``9``.
|
|
|
|
|
|
"""
|
|
|
- if not args:
|
|
|
- args = []
|
|
|
- if not kwargs:
|
|
|
- kwargs = []
|
|
|
- message_opts = {"routing_key": routing_key,
|
|
|
- "immediate": immediate,
|
|
|
- "mandatory": mandatory,
|
|
|
- "priority": priority}
|
|
|
- for option_name, option_value in message_opts.items():
|
|
|
- message_opts[option_name] = getattr(task, option_name, option_value)
|
|
|
+ args = args or []
|
|
|
+ kwargs = kwargs or {}
|
|
|
+ routing_key = routing_key or getattr(task, "routing_key", None)
|
|
|
+ immediate = immediate or getattr(task, "immediate", None)
|
|
|
+ mandatory = mandatory or getattr(task, "mandatory", None)
|
|
|
+ priority = priority or getattr(task, "priority", None)
|
|
|
+ taskset_id = opts.get("taskset_id")
|
|
|
+ publisher = opts.get("publisher")
|
|
|
|
|
|
need_to_close_connection = False
|
|
|
- if not connection:
|
|
|
- connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
|
|
|
- need_to_close_connection = True
|
|
|
+ if not publisher:
|
|
|
+ if not connection:
|
|
|
+ connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
|
|
|
+ need_to_close_connection = True
|
|
|
+ publisher = TaskPublisher(connection=connection)
|
|
|
+
|
|
|
+ delay_task = publisher.delay_task
|
|
|
+ if taskset_id:
|
|
|
+ delay_task = curry(publisher.delay_task_in_set, taskset_id)
|
|
|
+
|
|
|
+ task_id = delay_task(task.name, args, kwargs,
|
|
|
+ routing_key=routing_key, mandatory=mandatory,
|
|
|
+ immediate=immediate, priority=priority)
|
|
|
|
|
|
- publisher = TaskPublisher(connection=connection)
|
|
|
- task_id = publisher.delay_task(task.name, args, kwargs, **message_opts)
|
|
|
- publisher.close()
|
|
|
if need_to_close_connection:
|
|
|
+ publisher.close()
|
|
|
connection.close()
|
|
|
+
|
|
|
return AsyncResult(task_id)
|
|
|
|
|
|
|
|
@@ -323,9 +331,12 @@ class TaskSet(object):
|
|
|
def __init__(self, task, args):
|
|
|
try:
|
|
|
task_name = task.name
|
|
|
+ task_obj = task
|
|
|
except AttributeError:
|
|
|
task_name = task
|
|
|
+ task_obj = tasks[task_name]
|
|
|
|
|
|
+ self.task = task_obj
|
|
|
self.task_name = task_name
|
|
|
self.arguments = args
|
|
|
self.total = len(args)
|
|
@@ -363,14 +374,12 @@ class TaskSet(object):
|
|
|
taskset_id = str(uuid.uuid4())
|
|
|
conn = DjangoAMQPConnection(connect_timeout=connect_timeout)
|
|
|
publisher = TaskPublisher(connection=conn)
|
|
|
- subtask_ids = [publisher.delay_task_in_set(task_name=self.task_name,
|
|
|
- taskset_id=taskset_id,
|
|
|
- task_args=arg,
|
|
|
- task_kwargs=kwarg)
|
|
|
- for arg, kwarg in self.arguments]
|
|
|
+ subtasks = [apply_async(self.task, args, kwargs,
|
|
|
+ taskset_id=taskset_id, publisher=publisher)
|
|
|
+ for args, kwargs in self.arguments]
|
|
|
publisher.close()
|
|
|
conn.close()
|
|
|
- return TaskSetResult(taskset_id, subtask_ids)
|
|
|
+ return TaskSetResult(taskset_id, subtasks)
|
|
|
|
|
|
def iterate(self):
|
|
|
"""Iterate over the results returned after calling :meth:`run`.
|