Browse Source

TaskSet renamed to group, and is now app-agnostic

Ask Solem 13 years ago
parent
commit
32d4de9f70
1 changed files with 30 additions and 12 deletions
  1. 30 12
      celery/task/sets.py

+ 30 - 12
celery/task/sets.py

@@ -46,7 +46,8 @@ class subtask(AttributeDict):
 
 
     """
     """
 
 
-    def __init__(self, task=None, args=None, kwargs=None, options=None, **ex):
+    def __init__(self, task=None, args=None, kwargs=None, options=None,
+                type=None, **ex):
         init = super(subtask, self).__init__
         init = super(subtask, self).__init__
 
 
         if isinstance(task, dict):
         if isinstance(task, dict):
@@ -55,6 +56,7 @@ class subtask(AttributeDict):
         # Also supports using task class/instance instead of string name.
         # Also supports using task class/instance instead of string name.
         try:
         try:
             task_name = task.name
             task_name = task.name
+            self._type = task
         except AttributeError:
         except AttributeError:
             task_name = task
             task_name = task
 
 
@@ -117,7 +119,7 @@ class subtask(AttributeDict):
 
 
     @cached_property
     @cached_property
     def type(self):
     def type(self):
-        return current_app.tasks[self.task]
+        return self._type or current_app.tasks[self.task]
 
 
 
 
 def maybe_subtask(t):
 def maybe_subtask(t):
@@ -126,7 +128,7 @@ def maybe_subtask(t):
     return t
     return t
 
 
 
 
-class TaskSet(UserList):
+class group(UserList):
     """A task containing several subtasks, making it possible
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks have been completed.
     to track how many, or when all of the tasks have been completed.
 
 
@@ -135,19 +137,16 @@ class TaskSet(UserList):
     Example::
     Example::
 
 
         >>> urls = ("http://cnn.com/rss", "http://bbc.co.uk/rss")
         >>> urls = ("http://cnn.com/rss", "http://bbc.co.uk/rss")
-        >>> taskset = TaskSet(refresh_feed.subtask((url, )) for url in urls)
-        >>> taskset_result = taskset.apply_async()
-        >>> list_of_return_values = taskset_result.join()  # *expensive*
+        >>> g = group(refresh_feed.subtask((url, )) for url in urls)
+        >>> group_result = g.apply_async()
+        >>> list_of_return_values = group_result.join()  # *expensive*
 
 
     """
     """
-    #: Total number of subtasks in this set.
-    total = None
 
 
     def __init__(self, tasks=None, app=None, Publisher=None):
     def __init__(self, tasks=None, app=None, Publisher=None):
-        self.app = app_or_default(app)
+        self._app = app
         self.data = [maybe_subtask(t) for t in tasks or []]
         self.data = [maybe_subtask(t) for t in tasks or []]
-        self.total = len(self.tasks)
-        self.Publisher = Publisher or self.app.amqp.TaskPublisher
+        self._Publisher = Publisher
 
 
     def apply_async(self, connection=None, connect_timeout=None,
     def apply_async(self, connection=None, connect_timeout=None,
             publisher=None, taskset_id=None):
             publisher=None, taskset_id=None):
@@ -184,10 +183,29 @@ class TaskSet(UserList):
     def _sync_results(self, taskset_id):
     def _sync_results(self, taskset_id):
         return [task.apply(taskset_id=taskset_id) for task in self.tasks]
         return [task.apply(taskset_id=taskset_id) for task in self.tasks]
 
 
+    @property
+    def total(self):
+        """Number of subtasks in this group."""
+        return len(self)
+
+    def _get_app(self):
+        return self._app or self.data[0].type.app
+
+    def _set_app(self, app):
+        self._app = app
+    app = property(_get_app, _set_app)
+
     def _get_tasks(self):
     def _get_tasks(self):
         return self.data
         return self.data
 
 
     def _set_tasks(self, tasks):
     def _set_tasks(self, tasks):
         self.data = tasks
         self.data = tasks
     tasks = property(_get_tasks, _set_tasks)
     tasks = property(_get_tasks, _set_tasks)
-group = TaskSet
+
+    def _get_Publisher(self):
+        return self._Publisher or self.app.amqp.TaskPublisher
+
+    def _set_Publisher(self, Publisher):
+        self._Publisher = Publisher
+    Publisher = property(_get_Publisher, _set_Publisher)
+TaskSet = group