Browse Source

Renamed PendingResult -> AsyncResult and added multiprocessing.AsyncResult-like
method aliases. Added dmap and dmap_async, plus some more docstrings.

Ask Solem 16 years ago
parent
commit
2857d858c9
1 changed file with 114 additions and 22 deletions
  1. 114 22
      celery/task.py

+ 114 - 22
celery/task.py

@@ -14,40 +14,66 @@ import pickle
 import traceback
 import traceback
 
 
 
 
-class BasePendingResult(object):
+class BaseAsyncResult(object):
     """Base class for pending result, takes ``backend`` argument."""
     """Base class for pending result, takes ``backend`` argument."""
+
     def __init__(self, task_id, backend):
     def __init__(self, task_id, backend):
         self.task_id = task_id
         self.task_id = task_id
         self.backend = backend
         self.backend = backend
 
 
+    def is_done(self):
+        """Returns ``True`` if the task executed successfully."""
+        return self.backend.is_done(self.task_id)
+
+    def get(self):
+        """Alias to ``wait_for``."""
+        return self.wait_for()
+
+    def wait_for(self, timeout=None):
+        """Return the result when it arrives.
+        
+        If timeout is not ``None`` and the result does not arrive within
+        ``timeout`` seconds then ``celery.backends.base.TimeoutError`` is
+        raised. If the remote call raised an exception then that exception
+        will be reraised by get()."""
+        return self.backend.wait_for(self.task_id, timeout=timeout)
+
+    def ready(self):
+        """Returns ``True`` if the task executed successfully, or raised
+        an exception. If the task is still pending, or is waiting for retry
+        then ``False`` is returned."""
+        status = self.backend.get_status(self.task_id)
+        return status != "PENDING" or status != "RETRY"
+
+    def successful(self):
+        """Alias to ``is_done``."""
+        return self.is_done()
+
     def __str__(self):
     def __str__(self):
+        """str(self) -> self.task_id"""
         return self.task_id
         return self.task_id
 
 
     def __repr__(self):
     def __repr__(self):
-        return "<Job: %s>" % self.task_id
-
-    def is_done(self):
-        return self.backend.is_done(self.task_id)
-
-    def wait_for(self):
-        return self.backend.wait_for(self.task_id)
+        return "<AsyncResult: %s>" % self.task_id
 
 
     @property
     @property
     def result(self):
     def result(self):
-        if self.status == "DONE":
+        """The task's resulting value."""
+        if self.status == "DONE" or self.status == "FAILURE":
             return self.backend.get_result(self.task_id)
             return self.backend.get_result(self.task_id)
         return None
         return None
 
 
     @property
     @property
     def status(self):
     def status(self):
+        """The current status of the task."""
         return self.backend.get_status(self.task_id)
         return self.backend.get_status(self.task_id)
 
 
 
 
-class PendingResult(BasePendingResult):
+class AsyncResult(BaseAsyncResult):
     """Pending task result using the default backend.""" 
     """Pending task result using the default backend.""" 
 
 
     def __init__(self, task_id):
     def __init__(self, task_id):
-        super(PendingResult, self).__init__(task_id, backend=default_backend)
+        super(AsyncResult, self).__init__(task_id, backend=default_backend)
 
 
 
 
 def delay_task(task_name, *args, **kwargs):
 def delay_task(task_name, *args, **kwargs):
@@ -65,7 +91,7 @@ def delay_task(task_name, *args, **kwargs):
     publisher = TaskPublisher(connection=DjangoAMQPConnection())
     publisher = TaskPublisher(connection=DjangoAMQPConnection())
     task_id = publisher.delay_task(task_name, *args, **kwargs)
     task_id = publisher.delay_task(task_name, *args, **kwargs)
     publisher.close()
     publisher.close()
-    return PendingResult(task_id)
+    return AsyncResult(task_id)
 
 
 
 
 def discard_all():
 def discard_all():
@@ -120,11 +146,16 @@ class Task(object):
         ...         logger = self.get_logger(**kwargs)
         ...         logger = self.get_logger(**kwargs)
         ...         logger.info("Running MyTask with arg some_arg=%s" %
         ...         logger.info("Running MyTask with arg some_arg=%s" %
         ...                     some_arg))
         ...                     some_arg))
+        ...         return 42
         ... tasks.register(MyTask)
         ... tasks.register(MyTask)
 
 
     You can delay the task using the classmethod ``delay``...
     You can delay the task using the classmethod ``delay``...
 
 
-        >>> MyTask.delay(some_arg="foo")
+        >>> result = MyTask.delay(some_arg="foo")
+        >>> result.status # after some time
+        'DONE'
+        >>> result.result
+        42
 
 
     ...or using the ``celery.task.delay_task`` function, by passing the
     ...or using the ``celery.task.delay_task`` function, by passing the
     name of the task.
     name of the task.
@@ -179,26 +210,33 @@ class Task(object):
 
 
 
 
 class PositionQueue(UserList):
 class PositionQueue(UserList):
+    """A positional queue with filled/unfilled slots."""
 
 
     class UnfilledPosition(object):
     class UnfilledPosition(object):
+        """Describes an unfilled slot."""
         def __init__(self, position):
         def __init__(self, position):
             self.position = position
             self.position = position
 
 
     def __init__(self, length):
     def __init__(self, length):
+        """Initialize a position queue with ``length`` slots."""
         self.length = length
         self.length = length
         self.data = map(self.UnfilledPosition, xrange(length))
         self.data = map(self.UnfilledPosition, xrange(length))
 
 
     def is_full(self):
     def is_full(self):
+        """Returns ``True`` if all the positions have been filled."""
         return len(self) >= self.length
         return len(self) >= self.length
 
 
     def __len__(self):
     def __len__(self):
+        """len(self) -> number of positions filled with real values."""
         return len(self.filled)
         return len(self.filled)
 
 
     @property
     @property
     def filled(self):
     def filled(self):
+        """Returns the filled slots as a list."""
         return filter(lambda v: not isinstance(v, self.UnfilledPosition),
         return filter(lambda v: not isinstance(v, self.UnfilledPosition),
                       self)
                       self)
 
 
+
 class TaskSet(object):
 class TaskSet(object):
     """A task containing several subtasks, making it possible
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks are completed.
     to track how many, or when all of the tasks are completed.
@@ -264,10 +302,21 @@ class TaskSet(object):
         publisher.close()
         publisher.close()
         return taskset_id, subtask_ids
         return taskset_id, subtask_ids
 
 
-    def get_async(self, timeout=None):
+    def xget(self):
+        taskset_id, subtask_ids = self.run()
+        results = dict([(task_id, AsyncResult(task_id))
+                            for task_id in subtask_ids])
+        while results:
+            for pending_result in results:
+                if pending_result.status == "DONE":
+                    yield pending_result.result
+                elif pending_result.status == "FAILURE":
+                    raise pending_result.result
+
+    def join(self, timeout=None):
         time_start = time.time()
         time_start = time.time()
         taskset_id, subtask_ids = self.run()
         taskset_id, subtask_ids = self.run()
-        pending_results = map(PendingResult, subtask_ids)
+        pending_results = map(AsyncResult, subtask_ids)
         results = PositionQueue(length=len(subtask_ids))
         results = PositionQueue(length=len(subtask_ids))
 
 
         while True:
         while True:
@@ -281,24 +330,67 @@ class TaskSet(object):
             if timeout and time.time() > time_start + timeout:
             if timeout and time.time() > time_start + timeout:
                 raise TimeOutError("The map operation timed out.")
                 raise TimeOutError("The map operation timed out.")
 
 
+    @classmethod
+    def remote_execute(cls, func, args):
+        pickled = pickle.dumps(func)
+        arguments = [[[pickled, arg, {}], {}] for arg in args]
+        return cls(ExecuteRemoteTask, arguments)
 
 
-def map_async(func, args, timeout=None):
+    @classmethod
+    def map(cls, func, args, timeout=None):
+        remote_task = cls.remote_execute(func, args)
+        return remote_task.join(timeout=timeout)
+
+    @classmethod
+    def map_async(cls, func, args, timeout=None):
+        serfunc = pickle.dumps(func)
+        return AsynchronousMapTask.delay(serfunc, args, timeout=timeout)
+
+
+
+def dmap(func, args, timeout=None):
     """Distribute processing of the arguments and collect the results.
     """Distribute processing of the arguments and collect the results.
 
 
     Example
     Example
     --------
     --------
 
 
-        >>> from celery.task import map_async
+        >>> from celery.task import dmap
         >>> import operator
         >>> import operator
-        >>> map_async(operator.add, [[2, 2], [4, 4], [8, 8]])
+        >>> dmap(operator.add, [[2, 2], [4, 4], [8, 8]])
         [4, 8, 16]
         [4, 8, 16]
 
 
     """
     """
-    pickled = pickle.dumps(func)
-    arguments = [[[pickled, arg, {}], {}] for arg in args]
-    taskset = TaskSet(ExecuteRemoteTask, arguments)
-    return taskset.get_async(timeout=timeout)
+    return TaskSet.map(func, args, timeout=timeout)
+
+class AsynchronousMapTask(Task):
+    name = "celery.map_async"
+
+    def run(self, serfunc, args, **kwargs):
+        timeout = kwargs.get("timeout")
+        logger = self.get_logger(**kwargs)
+        logger.info("<<<<<<< ASYNCMAP: %s(%s)" % (serfunc, args))
+        return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
+tasks.register(AsynchronousMapTask)
 
 
+def dmap_async(func, args, timeout=None):
+    """Distribute processing of the arguments and collect the results
+    asynchronously. Returns a :class:`AsyncResult` object.
+
+    Example
+    --------
+
+        >>> from celery.task import dmap_async
+        >>> import operator
+        >>> presult = dmap_async(operator.add, [[2, 2], [4, 4], [8, 8]])
+        >>> presult
+        <AsyncResult: 373550e8-b9a0-4666-bc61-ace01fa4f91d>
+        >>> presult.status
+        'DONE'
+        >>> presult.result
+        [4, 8, 16]
+
+    """
+    return TaskSet.map_async(func, args, timeout=timeout)
 
 
 class PeriodicTask(Task):
 class PeriodicTask(Task):
     """A periodic task is a task that behaves like a cron job.
     """A periodic task is a task that behaves like a cron job.