Browse Source

TaskSet renamed to group, and is now app-agnostic

Ask Solem 13 years ago
parent
commit
32d4de9f70
1 changed files with 30 additions and 12 deletions
  1. 30 12
      celery/task/sets.py

+ 30 - 12
celery/task/sets.py

@@ -46,7 +46,8 @@ class subtask(AttributeDict):
 
     """
 
-    def __init__(self, task=None, args=None, kwargs=None, options=None, **ex):
+    def __init__(self, task=None, args=None, kwargs=None, options=None,
+                type=None, **ex):
         init = super(subtask, self).__init__
 
         if isinstance(task, dict):
@@ -55,6 +56,7 @@ class subtask(AttributeDict):
         # Also supports using task class/instance instead of string name.
         try:
             task_name = task.name
+            self._type = task
         except AttributeError:
             task_name = task
 
@@ -117,7 +119,7 @@ class subtask(AttributeDict):
 
     @cached_property
     def type(self):
-        return current_app.tasks[self.task]
+        return self._type or current_app.tasks[self.task]
 
 
 def maybe_subtask(t):
@@ -126,7 +128,7 @@ def maybe_subtask(t):
     return t
 
 
-class TaskSet(UserList):
+class group(UserList):
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks have been completed.
 
@@ -135,19 +137,16 @@ class TaskSet(UserList):
     Example::
 
         >>> urls = ("http://cnn.com/rss", "http://bbc.co.uk/rss")
-        >>> taskset = TaskSet(refresh_feed.subtask((url, )) for url in urls)
-        >>> taskset_result = taskset.apply_async()
-        >>> list_of_return_values = taskset_result.join()  # *expensive*
+        >>> g = group(refresh_feed.subtask((url, )) for url in urls)
+        >>> group_result = g.apply_async()
+        >>> list_of_return_values = group_result.join()  # *expensive*
 
     """
-    #: Total number of subtasks in this set.
-    total = None
 
     def __init__(self, tasks=None, app=None, Publisher=None):
-        self.app = app_or_default(app)
+        self._app = app
         self.data = [maybe_subtask(t) for t in tasks or []]
-        self.total = len(self.tasks)
-        self.Publisher = Publisher or self.app.amqp.TaskPublisher
+        self._Publisher = Publisher
 
     def apply_async(self, connection=None, connect_timeout=None,
             publisher=None, taskset_id=None):
@@ -184,10 +183,29 @@ class TaskSet(UserList):
     def _sync_results(self, taskset_id):
         return [task.apply(taskset_id=taskset_id) for task in self.tasks]
 
+    @property
+    def total(self):
+        """Number of subtasks in this group."""
+        return len(self)
+
+    def _get_app(self):
+        return self._app or self.data[0].type.app
+
+    def _set_app(self, app):
+        self._app = app
+    app = property(_get_app, _set_app)
+
     def _get_tasks(self):
         return self.data
 
     def _set_tasks(self, tasks):
         self.data = tasks
     tasks = property(_get_tasks, _set_tasks)
-group = TaskSet
+
+    def _get_Publisher(self):
+        return self._Publisher or self.app.amqp.TaskPublisher
+
+    def _set_Publisher(self, Publisher):
+        self._Publisher = Publisher
+    Publisher = property(_get_Publisher, _set_Publisher)
+TaskSet = group