map_async!! This is pretty cool. Example usage:

        >>> from celery.task import map_async
        >>> import operator
        >>> map_async(operator.add, [[2, 2], [4, 4], [8, 8]])
        [4, 8, 16]
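
Under the hood this call is just shorthand for building a TaskSet of ExecuteRemoteTask subtasks over the pickled callable; roughly the following (a sketch based on the implementation added in this commit, not a separate public API):

        >>> import pickle
        >>> import operator
        >>> from celery.task import TaskSet, ExecuteRemoteTask
        >>> pickled = pickle.dumps(operator.add)
        >>> # one [positional-args, keyword-args] pair per subtask
        >>> arguments = [[[pickled, arg, {}], {}] for arg in [[2, 2], [4, 4], [8, 8]]]
        >>> TaskSet(ExecuteRemoteTask, arguments).get_async(timeout=None)
        [4, 8, 16]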
Ask Solem committed 16 years ago
commit 1d9ce63b01

1 changed file with 65 additions and 4 deletions:

celery/task.py (+65, -4)

@@ -7,6 +7,8 @@ from celery.models import TaskMeta
 from django.core.cache import cache
 from datetime import timedelta
 from celery.backends import default_backend
+from UserList import UserList
+import time
 import uuid
 import pickle
 import traceback
@@ -176,6 +178,27 @@ class Task(object):
         return delay_task(cls.name, *args, **kwargs)
 
 
+class PositionQueue(UserList):
+
+    class UnfilledPosition(object):
+        def __init__(self, position):
+            self.position = position
+
+    def __init__(self, length):
+        self.length = length
+        self.data = map(self.UnfilledPosition, xrange(length))
+
+    def is_full(self):
+        return len(self) >= self.length
+
+    def __len__(self):
+        return len(self.filled)
+
+    @property
+    def filled(self):
+        return filter(lambda v: not isinstance(v, self.UnfilledPosition),
+                      self)
+
 class TaskSet(object):
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks are completed.
@@ -215,7 +238,10 @@ class TaskSet(object):
 
         Examples
         --------
-            >>> ts = RefreshFeeds(["http://foo.com/rss", http://bar.com/rss"])
+            >>> ts = RefreshFeeds([
+            ...         [["http://foo.com/rss"], {}],
+            ...         [["http://bar.com/rss"], {}],
+            ... ])
             >>> taskset_id, subtask_ids = ts.run()
             >>> taskset_id
             "d2c9b261-8eff-4bfb-8459-1e1b72063514"
@@ -229,15 +255,50 @@ class TaskSet(object):
         taskset_id = str(uuid.uuid4())
         publisher = TaskPublisher(connection=DjangoAMQPConnection())
         subtask_ids = []
-        for arg in self.arguments:
+        for arg, kwarg in self.arguments:
             subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
                                                      taskset_id=taskset_id,
-                                                     task_args=[],
-                                                     task_kwargs=arg)
+                                                     task_args=arg,
+                                                     task_kwargs=kwarg)
             subtask_ids.append(subtask_id) 
         publisher.close()
         return taskset_id, subtask_ids
 
+    def get_async(self, timeout=None):
+        time_start = time.time()
+        taskset_id, subtask_ids = self.run()
+        pending_results = map(PendingResult, subtask_ids)
+        results = PositionQueue(length=len(subtask_ids))
+
+        while True:
+            for i, pending_result in enumerate(pending_results):
+                if pending_result.status == "DONE":
+                    results[i] = pending_result.result
+                elif pending_result.status == "FAILURE":
+                    raise pending_result.result
+            if results.is_full():
+                return list(results)
+            if timeout and time.time() > time_start + timeout:
+                raise TimeOutError("The map operation timed out.")
+
+
+def map_async(func, args, timeout=None):
+    """Distribute processing of the arguments and collect the results.
+
+    Example
+    --------
+
+        >>> from celery.task import map_async
+        >>> import operator
+        >>> map_async(operator.add, [[2, 2], [4, 4], [8, 8]])
+        [4, 8, 16]
+
+    """
+    pickled = pickle.dumps(func)
+    arguments = [[[pickled, arg, {}], {}] for arg in args]
+    taskset = TaskSet(ExecuteRemoteTask, arguments)
+    return taskset.get_async(timeout=timeout)
+
 
 class PeriodicTask(Task):
     """A periodic task is a task that behaves like a cron job.