Browse Source

Introducing TaskSets

Ask Solem 16 years ago
parent
commit
88c28ae0cf
2 changed files with 58 additions and 3 deletions
  1. 12 3
      celery/messaging.py
  2. 46 0
      celery/task.py

+ 12 - 3
celery/messaging.py

@@ -15,13 +15,22 @@ class TaskPublisher(Publisher):
     exchange = "celery"
     routing_key = "celery"
 
-    def delay_task(self, task_name, **kwargs):
+    def delay_task(self, task_name, **task_kwargs):
+        """Publish a message for the task named ``task_name``.
+
+        All keyword arguments are collected into ``task_kwargs`` and
+        handed to :meth:`_delay_task` as the message's extra data; the
+        result of that call is returned unchanged.
+        """
+        return self._delay_task(task_name=task_name, extra_data=task_kwargs)
+
+    def delay_task_in_set(self, task_name, taskset_id, task_kwargs):
+        """Publish a message for ``task_name`` as part of a task set.
+
+        ``taskset_id`` is forwarded to :meth:`_delay_task` as
+        ``part_of_set`` and ``task_kwargs`` (a mapping) as the message's
+        extra data.  Returns whatever :meth:`_delay_task` returns.
+        """
+        return self._delay_task(
+            task_name=task_name,
+            extra_data=task_kwargs,
+            part_of_set=taskset_id)
+
+    def _delay_task(self, task_name, part_of_set=None, extra_data=None):
+        """Build and send the message for a single task.
+
+        ``extra_data`` (a mapping, or ``None`` for no extra data) is
+        copied into the message first; ``celeryTASK`` and ``celeryID``
+        are set afterwards, so they replace any same-named keys in
+        ``extra_data``.  When ``part_of_set`` is truthy it is stored
+        under ``celeryTASKSET``.
+
+        Returns the newly generated task id (a UUID4 string), which
+        ``delay_task`` and ``delay_task_in_set`` pass on to their callers.
+        """
+        extra_data = extra_data or {}
         task_id = str(uuid.uuid4())
-        message_data = dict(kwargs)
+        message_data = dict(extra_data)
         message_data["celeryTASK"] = task_name
         message_data["celeryID"] = task_id
+        if part_of_set:
+            message_data["celeryTASKSET"] = part_of_set
         self.send(message_data)
         return task_id
 
 
 class TaskConsumer(NoProcessConsumer):

+ 46 - 0
celery/task.py

@@ -4,6 +4,7 @@ from celery.registry import tasks
 from celery.messaging import TaskPublisher, TaskConsumer
 from django.core.cache import cache
 from datetime import timedelta
+import uuid
 import traceback
 
 __all__ = ["delay_task", "discard_all", "gen_task_done_cache_key",
@@ -83,6 +84,51 @@ class Task(object):
         return delay_task(cls.name, **kwargs)
 
 
+class TaskSet(object):
+    """A set of subtasks published under one shared task set id,
+    intended to make it possible to track how many, or whether all,
+    of the subtasks have completed.
+
+    Example Usage
+    --------------
+
+        >>> from djangofeeds.tasks import RefreshFeedTask
+        >>> taskset = TaskSet(RefreshFeedTask, args=[
+        ...                 {"feed_url": "http://cnn.com/rss"},
+        ...                 {"feed_url": "http://bbc.com/rss"},
+        ...                 {"feed_url": "http://xkcd.com/rss"}])
+
+        >>> taskset_id, subtask_ids = taskset.run()
+
+    """
+
+    def __init__(self, task, args):
+        """``task`` can be either a fully qualified task name, or a task
+        class; ``args`` is a list holding one mapping of keyword
+        arguments per subtask.
+        """
+        # Anything exposing a ``name`` attribute is treated as a task
+        # class; otherwise ``task`` itself is used as the task name.
+        self.task_name = getattr(task, "name", task)
+        self.arguments = args
+        self.total = len(args)
+
+    def run(self):
+        """Publish one message per entry in ``self.arguments``.
+
+        A fresh UUID4 string is generated as the task set id and passed
+        along with every subtask.  Returns ``(taskset_id, subtask_ids)``
+        where ``subtask_ids`` holds, in order, the value returned by
+        ``delay_task_in_set`` for each subtask.
+        """
+        taskset_id = str(uuid.uuid4())
+        # NOTE(review): ``DjangoAMQPConnection`` is passed without being
+        # called; confirm whether ``TaskPublisher`` expects an instance.
+        publisher = TaskPublisher(connection=DjangoAMQPConnection)
+        try:
+            subtask_ids = [
+                publisher.delay_task_in_set(task_name=self.task_name,
+                                            taskset_id=taskset_id,
+                                            task_kwargs=arg)
+                for arg in self.arguments]
+        finally:
+            # Close the publisher even when publishing a subtask fails.
+            publisher.close()
+        return taskset_id, subtask_ids
+
+
 class PeriodicTask(Task):
     run_every = timedelta(days=1)
     type = "periodic"