|  | @@ -1,74 +1,93 @@
 | 
	
		
			
				|  |  |  from celery import conf
 | 
	
		
			
				|  |  |  from celery.execute import apply_async
 | 
	
		
			
				|  |  | +from celery.messaging import establish_connection, with_connection
 | 
	
		
			
				|  |  | +from celery.messaging import TaskPublisher
 | 
	
		
			
				|  |  |  from celery.registry import tasks
 | 
	
		
			
				|  |  |  from celery.result import TaskSetResult
 | 
	
		
			
				|  |  |  from celery.utils import gen_unique_id, padlist
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class TaskSet(object):
 | 
	
		
			
				|  |  | -    """A task containing several subtasks, making it possible
 | 
	
		
			
				|  |  | -    to track how many, or when all of the tasks has been completed.
 | 
	
		
			
				|  |  | +class subtask(object):
 | 
	
		
			
				|  |  | +    """A subtask part of a :class:`TaskSet`.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    :param task: The task class.
 | 
	
		
			
				|  |  | +    :keyword args: Positional arguments to apply.
 | 
	
		
			
				|  |  | +    :keyword kwargs: Keyword arguments to apply.
 | 
	
		
			
				|  |  | +    :keyword options: Additional options to
 | 
	
		
			
				|  |  | +      :func:`celery.execute.apply_async`.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    :param task: The task class or name.
 | 
	
		
			
				|  |  | -        Can either be a fully qualified task name, or a task class.
 | 
	
		
			
				|  |  | +    """
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    :param args: A list of args, kwargs pairs.
 | 
	
		
			
				|  |  | -        e.g. ``[[args1, kwargs1], [args2, kwargs2], ..., [argsN, kwargsN]]``
 | 
	
		
			
				|  |  | +    def __init__(self, task, args=None, kwargs=None, options=None):
 | 
	
		
			
				|  |  | +        self.task = task
 | 
	
		
			
				|  |  | +        self.args = args or ()
 | 
	
		
			
				|  |  | +        self.kwargs = kwargs or {}
 | 
	
		
			
				|  |  | +        self.options = options or {}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def apply(self, taskset_id):
 | 
	
		
			
				|  |  | +        """Apply this task locally."""
 | 
	
		
			
				|  |  | +        return self.task.apply(self.args, self.kwargs,
 | 
	
		
			
				|  |  | +                               taskset_id=taskset_id, **self.options)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    .. attribute:: task_name
 | 
	
		
			
				|  |  | +    def apply_async(self, taskset_id, publisher):
 | 
	
		
			
				|  |  | +        """Apply this task asynchronously."""
 | 
	
		
			
				|  |  | +        return self.task.apply_async(self.args, self.kwargs,
 | 
	
		
			
				|  |  | +                                     taskset_id=taskset_id,
 | 
	
		
			
				|  |  | +                                     publisher=publisher, **self.options)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        The name of the task.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    .. attribute:: arguments
 | 
	
		
			
				|  |  | +class TaskSet(object):
 | 
	
		
			
				|  |  | +    """A task containing several subtasks, making it possible
 | 
	
		
			
				|  |  | +    to track how many, or when all of the tasks has been completed.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        The arguments, as passed to the task set constructor.
 | 
	
		
			
				|  |  | +    :param tasks: A list of :class:`subtask`s.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      .. attribute:: total
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        Total number of tasks in this task set.
 | 
	
		
			
				|  |  | +        Total number of subtasks in this task set.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      Example
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          >>> from djangofeeds.tasks import RefreshFeedTask
 | 
	
		
			
				|  |  | -        >>> taskset = TaskSet(RefreshFeedTask, args=[
 | 
	
		
			
				|  |  | -        ...                 ([], {"feed_url": "http://cnn.com/rss"}),
 | 
	
		
			
				|  |  | -        ...                 ([], {"feed_url": "http://bbc.com/rss"}),
 | 
	
		
			
				|  |  | -        ...                 ([], {"feed_url": "http://xkcd.com/rss"})
 | 
	
		
			
				|  |  | -        ... ])
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        >>> from celery.task.sets import TaskSet, subtask
 | 
	
		
			
				|  |  | +        >>> urls = ("http://cnn.com/rss",
 | 
	
		
			
				|  |  | +        ...         "http://bbc.co.uk/rss",
 | 
	
		
			
				|  |  | +        ...         "http://xkcd.com/rss")
 | 
	
		
			
				|  |  | +        >>> subtasks = [subtask(RefreshFeedTask, kwargs={"feed_url": url})
 | 
	
		
			
				|  |  | +        ...                 for url in urls]
 | 
	
		
			
				|  |  | +        >>> taskset = TaskSet(tasks=subtasks)
 | 
	
		
			
				|  |  |          >>> taskset_result = taskset.apply_async()
 | 
	
		
			
				|  |  |          >>> list_of_return_values = taskset_result.join()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      """
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    def __init__(self, task, args):
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | -            task_name = task.name
 | 
	
		
			
				|  |  | -            task_obj = task
 | 
	
		
			
				|  |  | -        except AttributeError:
 | 
	
		
			
				|  |  | -            task_name = task
 | 
	
		
			
				|  |  | -            task_obj = tasks[task_name]
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        # Get task instance
 | 
	
		
			
				|  |  | -        task_obj = tasks[task_obj.name]
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        self.task = task_obj
 | 
	
		
			
				|  |  | -        self.task_name = task_name
 | 
	
		
			
				|  |  | -        self.arguments = args
 | 
	
		
			
				|  |  | -        self.total = len(args)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    def apply_async(self, connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
 | 
	
		
			
				|  |  | +    task = None # compat
 | 
	
		
			
				|  |  | +    task_name = None # compat
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, task=None, tasks=None):
 | 
	
		
			
				|  |  | +        # Previously TaskSet only supported applying one kind of task.
 | 
	
		
			
				|  |  | +        # the signature then was TaskSet(task, arglist)
 | 
	
		
			
				|  |  | +        # Convert the arguments to subtasks'.
 | 
	
		
			
				|  |  | +        if task is not None:
 | 
	
		
			
				|  |  | +            tasks = [subtask(task, *arglist) for arglist in tasks]
 | 
	
		
			
				|  |  | +            self.task = task
 | 
	
		
			
				|  |  | +            self.task_name = task.name
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.tasks = tasks
 | 
	
		
			
				|  |  | +        self.total = len(self.tasks)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @with_connection
 | 
	
		
			
				|  |  | +    def apply_async(self, connection=None,
 | 
	
		
			
				|  |  | +            connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
 | 
	
		
			
				|  |  |          """Run all tasks in the taskset.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          :returns: A :class:`celery.result.TaskSetResult` instance.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          Example
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            >>> ts = TaskSet(RefreshFeedTask, args=[
 | 
	
		
			
				|  |  | -            ...         (["http://foo.com/rss"], {}),
 | 
	
		
			
				|  |  | -            ...         (["http://bar.com/rss"], {}),
 | 
	
		
			
				|  |  | -            ... ])
 | 
	
		
			
				|  |  | +            >>> ts = TaskSet(tasks=(
 | 
	
		
			
				|  |  | +            ...         subtask(RefreshFeedTask, ["http://foo.com/rss"]),
 | 
	
		
			
				|  |  | +            ...         subtask(RefreshFeedTask, ["http://bar.com/rss"]),
 | 
	
		
			
				|  |  | +            ... ))
 | 
	
		
			
				|  |  |              >>> result = ts.apply_async()
 | 
	
		
			
				|  |  |              >>> result.taskset_id
 | 
	
		
			
				|  |  |              "d2c9b261-8eff-4bfb-8459-1e1b72063514"
 | 
	
	
		
			
				|  | @@ -92,28 +111,21 @@ class TaskSet(object):
 | 
	
		
			
				|  |  |              return self.apply()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          taskset_id = gen_unique_id()
 | 
	
		
			
				|  |  | -        conn = self.task.establish_connection(connect_timeout=connect_timeout)
 | 
	
		
			
				|  |  | -        publisher = self.task.get_publisher(connection=conn)
 | 
	
		
			
				|  |  | +        conn = connection or establish_connection(connect_timeout=connect_timeout)
 | 
	
		
			
				|  |  | +        publisher = TaskPublisher(connection=conn)
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            subtasks = [self.apply_part(arglist, taskset_id, publisher)
 | 
	
		
			
				|  |  | -                            for arglist in self.arguments]
 | 
	
		
			
				|  |  | +            results = [task.apply_async(taskset_id, publisher)
 | 
	
		
			
				|  |  | +                            for task in self.tasks]
 | 
	
		
			
				|  |  |          finally:
 | 
	
		
			
				|  |  |              publisher.close()
 | 
	
		
			
				|  |  | -            conn.close()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        return TaskSetResult(taskset_id, subtasks)
 | 
	
		
			
				|  |  | +            connection or conn.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def apply_part(self, arglist, taskset_id, publisher):
 | 
	
		
			
				|  |  | -        """Apply a single part of the taskset."""
 | 
	
		
			
				|  |  | -        args, kwargs, opts = padlist(arglist, 3, default={})
 | 
	
		
			
				|  |  | -        return apply_async(self.task, args, kwargs,
 | 
	
		
			
				|  |  | -                           taskset_id=taskset_id, publisher=publisher, **opts)
 | 
	
		
			
				|  |  | +        return TaskSetResult(taskset_id, results)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def apply(self):
 | 
	
		
			
				|  |  |          """Applies the taskset locally."""
 | 
	
		
			
				|  |  |          taskset_id = gen_unique_id()
 | 
	
		
			
				|  |  | -        subtasks = [apply(self.task, args, kwargs)
 | 
	
		
			
				|  |  | -                        for args, kwargs in self.arguments]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # This will be filled with EagerResults.
 | 
	
		
			
				|  |  | -        return TaskSetResult(taskset_id, subtasks)
 | 
	
		
			
				|  |  | +        return TaskSetResult(taskset_id, [task.apply(taskset_id)
 | 
	
		
			
				|  |  | +                                            for task in self.tasks])
 |