|
@@ -4,14 +4,12 @@ Working with tasks and task sets.
|
|
|
|
|
|
"""
|
|
|
from carrot.connection import DjangoAMQPConnection
|
|
|
+from celery.messaging import TaskPublisher, TaskConsumer
|
|
|
from celery.log import setup_logger
|
|
|
from celery.registry import tasks
|
|
|
-from celery.messaging import TaskPublisher, TaskConsumer
|
|
|
from datetime import timedelta
|
|
|
from celery.backends import default_backend
|
|
|
-from celery.datastructures import PositionQueue
|
|
|
-from celery.result import AsyncResult
|
|
|
-from celery.timer import TimeoutTimer
|
|
|
+from celery.result import AsyncResult, TaskSetResult
|
|
|
import uuid
|
|
|
import pickle
|
|
|
|
|
@@ -54,10 +52,10 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
|
|
|
for option_name, option_value in message_opts.items():
|
|
|
message_opts[option_name] = getattr(task, option_name, option_value)
|
|
|
|
|
|
- amqp_connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
|
|
|
- publisher = TaskPublisher(connection=amqp_connection)
|
|
|
- task_id = publisher.delay_task(task.name, args, kwargs, **message_opts)
|
|
|
- amqp_connection.close()
|
|
|
+ with DjangoAMQPConnection(connect_timeout=connect_timeout) as conn:
|
|
|
+ with TaskPublisher(connection=conn) as publisher:
|
|
|
+ task_id = publisher.delay_task(task.name, args, kwargs,
|
|
|
+ **message_opts)
|
|
|
return AsyncResult(task_id)
|
|
|
|
|
|
|
|
@@ -310,7 +308,7 @@ class TaskSet(object):
|
|
|
... [], {"feed_url": "http://bbc.com/rss"},
|
|
|
... [], {"feed_url": "http://xkcd.com/rss"}])
|
|
|
|
|
|
- >>> taskset_id, subtask_ids = taskset.run()
|
|
|
+ >>> taskset_result = taskset.run()
|
|
|
>>> list_of_return_values = taskset.join()
|
|
|
|
|
|
|
|
@@ -329,39 +327,43 @@ class TaskSet(object):
|
|
|
def run(self):
|
|
|
"""Run all tasks in the taskset.
|
|
|
|
|
|
- :returns: A tuple containing the taskset id, and a list
|
|
|
- of subtask ids.
|
|
|
-
|
|
|
- :rtype: tuple
|
|
|
+ :returns: A :class:`celery.result.TaskSetResult` instance.
|
|
|
|
|
|
Example
|
|
|
|
|
|
- >>> ts = RefreshFeeds([
|
|
|
+ >>> ts = TaskSet(RefreshFeedTask, [
|
|
|
... ["http://foo.com/rss", {}],
|
|
|
... ["http://bar.com/rss", {}],
|
|
|
... )
|
|
|
- >>> taskset_id, subtask_ids = ts.run()
|
|
|
- >>> taskset_id
|
|
|
+ >>> result = ts.run()
|
|
|
+ >>> result.taskset_id
|
|
|
"d2c9b261-8eff-4bfb-8459-1e1b72063514"
|
|
|
- >>> subtask_ids
|
|
|
+ >>> result.subtask_ids
|
|
|
["b4996460-d959-49c8-aeb9-39c530dcde25",
|
|
|
"598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
|
|
|
+ >>> result.waiting()
|
|
|
+ True
|
|
|
>>> time.sleep(10)
|
|
|
- >>> is_done(taskset_id)
|
|
|
+ >>> result.ready()
|
|
|
+ True
|
|
|
+ >>> result.successful()
|
|
|
True
|
|
|
+ >>> result.failed()
|
|
|
+ False
|
|
|
+ >>> result.join()
|
|
|
+ [True, True]
|
|
|
+
|
|
|
"""
|
|
|
taskset_id = str(uuid.uuid4())
|
|
|
- amqp_connection = DjangoAMQPConnection()
|
|
|
- publisher = TaskPublisher(connection=amqp_connection)
|
|
|
- subtask_ids = []
|
|
|
- for arg, kwarg in self.arguments:
|
|
|
- subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
|
|
|
- taskset_id=taskset_id,
|
|
|
- task_args=arg,
|
|
|
- task_kwargs=kwarg)
|
|
|
- subtask_ids.append(subtask_id)
|
|
|
- amqp_connection.close()
|
|
|
- return taskset_id, subtask_ids
|
|
|
+ with DjangoAMQPConnection() as amqp_connection:
|
|
|
+ with TaskPublisher(connection=amqp_connection) as publisher:
|
|
|
+ subtask_ids = [publisher.delay_task_in_set(
|
|
|
+ task_name=self.task_name,
|
|
|
+ taskset_id=taskset_id,
|
|
|
+ task_args=arg,
|
|
|
+ task_kwargs=kwarg)
|
|
|
+ for arg, kwarg in self.arguments]
|
|
|
+ return TaskSetResult(taskset_id, subtask_ids)
|
|
|
|
|
|
def iterate(self):
|
|
|
"""Iterate over the results returned after calling :meth:`run`.
|
|
@@ -370,16 +372,7 @@ class TaskSet(object):
|
|
|
be re-raised.
|
|
|
|
|
|
"""
|
|
|
- taskset_id, subtask_ids = self.run()
|
|
|
- results = dict([(task_id, AsyncResult(task_id))
|
|
|
- for task_id in subtask_ids])
|
|
|
- while results:
|
|
|
- for task_id, pending_result in results.items():
|
|
|
- if pending_result.status == "DONE":
|
|
|
- del(results[task_id])
|
|
|
- yield pending_result.result
|
|
|
- elif pending_result.status == "FAILURE":
|
|
|
- raise pending_result.result
|
|
|
+ return iter(self.run())
|
|
|
|
|
|
def join(self, timeout=None):
|
|
|
"""Gather the results for all of the tasks in the taskset,
|
|
@@ -398,24 +391,7 @@ class TaskSet(object):
|
|
|
:returns: list of return values for all tasks in the taskset.
|
|
|
|
|
|
"""
|
|
|
- timeout_timer = TimeoutTimer(timeout) # Timeout timer starts here.
|
|
|
- taskset_id, subtask_ids = self.run()
|
|
|
- pending_results = map(AsyncResult, subtask_ids)
|
|
|
- results = PositionQueue(length=len(subtask_ids))
|
|
|
-
|
|
|
- while True:
|
|
|
- for position, pending_result in enumerate(pending_results):
|
|
|
- if pending_result.status == "DONE":
|
|
|
- results[position] = pending_result.result
|
|
|
- elif pending_result.status == "FAILURE":
|
|
|
- raise pending_result.result
|
|
|
- if results.full():
|
|
|
- # Make list copy, so the returned type is not a position
|
|
|
- # queue.
|
|
|
- return list(results)
|
|
|
-
|
|
|
- # This raises TimeoutError when timed out.
|
|
|
- timeout_timer.tick()
|
|
|
+ return self.run().join(timeout=timeout)
|
|
|
|
|
|
@classmethod
|
|
|
def remote_execute(cls, func, args):
|