map_async!! This is pretty cool. Example usage:

        >>> from celery.task import map_async
        >>> import operator
        >>> map_async(operator.add, [[2, 2], [4, 4], [8, 8]])
        [4, 8, 16]
Ask Solem, 16 years ago
commit 1d9ce63b01
1 changed file with 65 additions and 4 deletions

celery/task.py  +65 -4

@@ -7,6 +7,8 @@ from celery.models import TaskMeta
 from django.core.cache import cache
 from datetime import timedelta
 from celery.backends import default_backend
+from UserList import UserList
+import time
 import uuid
 import pickle
 import traceback
@@ -176,6 +178,27 @@ class Task(object):
         return delay_task(cls.name, *args, **kwargs)
 
 
+class PositionQueue(UserList):
+
+    class UnfilledPosition(object):
+        def __init__(self, position):
+            self.position = position
+
+    def __init__(self, length):
+        self.length = length
+        self.data = map(self.UnfilledPosition, xrange(length))
+
+    def is_full(self):
+        return len(self) >= self.length
+
+    def __len__(self):
+        return len(self.filled)
+
+    @property
+    def filled(self):
+        return filter(lambda v: not isinstance(v, self.UnfilledPosition),
+                      self)
+
 class TaskSet(object):
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks are completed.
@@ -215,7 +238,10 @@ class TaskSet(object):
 
         Examples
         --------
-            >>> ts = RefreshFeeds(["http://foo.com/rss", http://bar.com/rss"])
+            >>> ts = RefreshFeeds([
+            ...         ["http://foo.com/rss", {}],
+            ...         ["http://bar.com/rss", {}],
+            ... )
             >>> taskset_id, subtask_ids = ts.run()
             >>> taskset_id
             "d2c9b261-8eff-4bfb-8459-1e1b72063514"
@@ -229,15 +255,50 @@ class TaskSet(object):
         taskset_id = str(uuid.uuid4())
         publisher = TaskPublisher(connection=DjangoAMQPConnection())
         subtask_ids = []
-        for arg in self.arguments:
+        for arg, kwarg in self.arguments:
             subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
                                                      taskset_id=taskset_id,
-                                                     task_args=[],
-                                                     task_kwargs=arg)
+                                                     task_args=arg,
+                                                     task_kwargs=kwarg)
             subtask_ids.append(subtask_id)
         publisher.close()
         return taskset_id, subtask_ids
 
+    def get_async(self, timeout=None):
+        time_start = time.time()
+        taskset_id, subtask_ids = self.run()
+        pending_results = map(PendingResult, subtask_ids)
+        results = PositionQueue(length=len(subtask_ids))
+
+        while True:
+            for i, pending_result in enumerate(pending_results):
+                if pending_result.status == "DONE":
+                    results[i] = pending_result.result
+                elif pending_result.status == "FAILURE":
+                    raise pending_result.result
+            if results.is_full():
+                return list(results)
+            if timeout and time.time() > time_start + timeout:
+                raise TimeOutError("The map operation timed out.")
+
+
+def map_async(func, args, timeout=None):
+    """Distribute processing of the arguments and collect the results.
+
+    Example
+    --------
+
+        >>> from celery.task import map_async
+        >>> import operator
+        >>> map_async(operator.add, [[2, 2], [4, 4], [8, 8]])
+        [4, 8, 16]
+
+    """
+    pickled = pickle.dumps(func)
+    arguments = [[[pickled, arg, {}], {}] for arg in args]
+    taskset = TaskSet(ExecuteRemoteTask, arguments)
+    return taskset.get_async(timeout=timeout)
+
 
 class PeriodicTask(Task):
     """A periodic task is a task that behaves like a cron job.