Explorar el Código

TaskSets can now contain different kinds of tasks.

The previous invocation of ``TaskSet(task, args)`` is still supported,
but undocumented and will probably be deprecated in the future.

The new TaskSet syntax is:

    >>> from celery.task.sets import TaskSet, subtask

    >>> ts = TaskSet(tasks=(
    ...     subtask(tasks.add, args=[2, 2])
    ...     subtask(tasks.mul, args=[4, 4]),
    ...     subtask(tasks.add, args=[16, 16]),
    )
    >>> result = ts.apply_async()
    >>> result.join()
    [4, 16, 32]

Where the subtask task takes the following arguments:
    task, args, kwargs, options

Note that this implementation does not take overriden ``get_publisher``,
or ``establish_connection`` methods in account.

Closes #116. Thanks to jonozzz.
Ask Solem hace 15 años
padre
commit
3f44e1dba1
Se han modificado 1 ficheros con 67 adiciones y 55 borrados
  1. 67 55
      celery/task/sets.py

+ 67 - 55
celery/task/sets.py

@@ -1,74 +1,93 @@
 from celery import conf
 from celery.execute import apply_async
+from celery.messaging import establish_connection, with_connection
+from celery.messaging import TaskPublisher
 from celery.registry import tasks
 from celery.result import TaskSetResult
 from celery.utils import gen_unique_id, padlist
 
 
-class TaskSet(object):
-    """A task containing several subtasks, making it possible
-    to track how many, or when all of the tasks has been completed.
+class subtask(object):
+    """A subtask part of a :class:`TaskSet`.
+
+    :param task: The task class.
+    :keyword args: Positional arguments to apply.
+    :keyword kwargs: Keyword arguments to apply.
+    :keyword options: Additional options to
+      :func:`celery.execute.apply_async`.
 
-    :param task: The task class or name.
-        Can either be a fully qualified task name, or a task class.
+    """
 
-    :param args: A list of args, kwargs pairs.
-        e.g. ``[[args1, kwargs1], [args2, kwargs2], ..., [argsN, kwargsN]]``
+    def __init__(self, task, args=None, kwargs=None, options=None):
+        self.task = task
+        self.args = args or ()
+        self.kwargs = kwargs or {}
+        self.options = options or {}
 
+    def apply(self, taskset_id):
+        """Apply this task locally."""
+        return self.task.apply(self.args, self.kwargs,
+                               taskset_id=taskset_id, **self.options)
 
-    .. attribute:: task_name
+    def apply_async(self, taskset_id, publisher):
+        """Apply this task asynchronously."""
+        return self.task.apply_async(self.args, self.kwargs,
+                                     taskset_id=taskset_id,
+                                     publisher=publisher, **self.options)
 
-        The name of the task.
 
-    .. attribute:: arguments
+class TaskSet(object):
+    """A task containing several subtasks, making it possible
+    to track how many, or when all of the tasks has been completed.
 
-        The arguments, as passed to the task set constructor.
+    :param tasks: A list of :class:`subtask`s.
 
     .. attribute:: total
 
-        Total number of tasks in this task set.
+        Total number of subtasks in this task set.
 
     Example
 
         >>> from djangofeeds.tasks import RefreshFeedTask
-        >>> taskset = TaskSet(RefreshFeedTask, args=[
-        ...                 ([], {"feed_url": "http://cnn.com/rss"}),
-        ...                 ([], {"feed_url": "http://bbc.com/rss"}),
-        ...                 ([], {"feed_url": "http://xkcd.com/rss"})
-        ... ])
-
+        >>> from celery.task.sets import TaskSet, subtask
+        >>> urls = ("http://cnn.com/rss",
+        ...         "http://bbc.co.uk/rss",
+        ...         "http://xkcd.com/rss")
+        >>> subtasks = [subtask(RefreshFeedTask, kwargs={"feed_url": url})
+        ...                 for url in urls]
+        >>> taskset = TaskSet(tasks=subtasks)
         >>> taskset_result = taskset.apply_async()
         >>> list_of_return_values = taskset_result.join()
 
     """
-
-    def __init__(self, task, args):
-        try:
-            task_name = task.name
-            task_obj = task
-        except AttributeError:
-            task_name = task
-            task_obj = tasks[task_name]
-
-        # Get task instance
-        task_obj = tasks[task_obj.name]
-
-        self.task = task_obj
-        self.task_name = task_name
-        self.arguments = args
-        self.total = len(args)
-
-    def apply_async(self, connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
+    task = None # compat
+    task_name = None # compat
+
+    def __init__(self, task=None, tasks=None):
+        # Previously TaskSet only supported applying one kind of task.
+        # the signature then was TaskSet(task, arglist)
+        # Convert the arguments to subtasks'.
+        if task is not None:
+            tasks = [subtask(task, *arglist) for arglist in tasks]
+            self.task = task
+            self.task_name = task.name
+
+        self.tasks = tasks
+        self.total = len(self.tasks)
+
+    @with_connection
+    def apply_async(self, connection=None,
+            connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
         """Run all tasks in the taskset.
 
         :returns: A :class:`celery.result.TaskSetResult` instance.
 
         Example
 
-            >>> ts = TaskSet(RefreshFeedTask, args=[
-            ...         (["http://foo.com/rss"], {}),
-            ...         (["http://bar.com/rss"], {}),
-            ... ])
+            >>> ts = TaskSet(tasks=(
+            ...         subtask(RefreshFeedTask, ["http://foo.com/rss"]),
+            ...         subtask(RefreshFeedTask, ["http://bar.com/rss"]),
+            ... ))
             >>> result = ts.apply_async()
             >>> result.taskset_id
             "d2c9b261-8eff-4bfb-8459-1e1b72063514"
@@ -92,28 +111,21 @@ class TaskSet(object):
             return self.apply()
 
         taskset_id = gen_unique_id()
-        conn = self.task.establish_connection(connect_timeout=connect_timeout)
-        publisher = self.task.get_publisher(connection=conn)
+        conn = connection or establish_connection(connect_timeout=connect_timeout)
+        publisher = TaskPublisher(connection=conn)
         try:
-            subtasks = [self.apply_part(arglist, taskset_id, publisher)
-                            for arglist in self.arguments]
+            results = [task.apply_async(taskset_id, publisher)
+                            for task in self.tasks]
         finally:
             publisher.close()
-            conn.close()
-
-        return TaskSetResult(taskset_id, subtasks)
+            connection or conn.close()
 
-    def apply_part(self, arglist, taskset_id, publisher):
-        """Apply a single part of the taskset."""
-        args, kwargs, opts = padlist(arglist, 3, default={})
-        return apply_async(self.task, args, kwargs,
-                           taskset_id=taskset_id, publisher=publisher, **opts)
+        return TaskSetResult(taskset_id, results)
 
     def apply(self):
         """Applies the taskset locally."""
         taskset_id = gen_unique_id()
-        subtasks = [apply(self.task, args, kwargs)
-                        for args, kwargs in self.arguments]
 
         # This will be filled with EagerResults.
-        return TaskSetResult(taskset_id, subtasks)
+        return TaskSetResult(taskset_id, [task.apply(taskset_id)
+                                            for task in self.tasks])