Browse Source

Timeouts refactored into celery.timer.TimeoutTimer

Ask Solem 16 years ago
parent
commit
06bfab9844
3 changed files with 24 additions and 84 deletions
  1. 3 11
      celery/backends/base.py
  2. 8 73
      celery/task.py
  3. 13 0
      celery/timer.py

+ 3 - 11
celery/backends/base.py

@@ -1,14 +1,8 @@
-import time
-
-
-class TimeOutError(Exception):
-    """The operation has timed out."""
+from celery.timer import TimeoutTimer
 
 
 
 
 class BaseBackend(object):
 class BaseBackend(object):
 
 
-    TimeOutError = TimeOutError
-    
     def __init__(self):
     def __init__(self):
         pass
         pass
 
 
@@ -45,13 +39,11 @@ class BaseBackend(object):
         pass
         pass
 
 
     def wait_for(self, task_id, timeout=None):
     def wait_for(self, task_id, timeout=None):
-        time_start = time.time()
+        timeout_timer = TimeoutTimer(timeout)
         while True:
         while True:
             status = self.get_status(task_id)
             status = self.get_status(task_id)
             if status == "DONE":
             if status == "DONE":
                 return self.get_result(task_id)
                 return self.get_result(task_id)
             elif status == "FAILURE":
             elif status == "FAILURE":
                 raise self.get_result(task_id)
                 raise self.get_result(task_id)
-            if timeout and time.time() > time_start + timeout:
-                raise self.TimeOutError(
-                        "Timed out while waiting for task %s" % (task_id))
+            timeout_timer.tick()

+ 8 - 73
celery/task.py

@@ -8,72 +8,10 @@ from django.core.cache import cache
 from datetime import timedelta
 from datetime import timedelta
 from celery.backends import default_backend
 from celery.backends import default_backend
 from celery.datastructures import PositionQueue
 from celery.datastructures import PositionQueue
-import time
+from celery.result import AsyncResult
+from celery.timer import TimeoutTimer
 import uuid
 import uuid
 import pickle
 import pickle
-import traceback
-
-
-class BaseAsyncResult(object):
-    """Base class for pending result, takes ``backend`` argument."""
-
-    def __init__(self, task_id, backend):
-        self.task_id = task_id
-        self.backend = backend
-
-    def is_done(self):
-        """Returns ``True`` if the task executed successfully."""
-        return self.backend.is_done(self.task_id)
-
-    def get(self):
-        """Alias to ``wait``."""
-        return self.wait()
-
-    def wait(self, timeout=None):
-        """Return the result when it arrives.
-        
-        If timeout is not ``None`` and the result does not arrive within
-        ``timeout`` seconds then ``celery.backends.base.TimeoutError`` is
-        raised. If the remote call raised an exception then that exception
-        will be reraised by get()."""
-        return self.backend.wait_for(self.task_id, timeout=timeout)
-
-    def ready(self):
-        """Returns ``True`` if the task executed successfully, or raised
-        an exception. If the task is still pending, or is waiting for retry
-        then ``False`` is returned."""
-        status = self.backend.get_status(self.task_id)
-        return status != "PENDING" or status != "RETRY"
-
-    def successful(self):
-        """Alias to ``is_done``."""
-        return self.is_done()
-
-    def __str__(self):
-        """str(self) -> self.task_id"""
-        return self.task_id
-
-    def __repr__(self):
-        return "<AsyncResult: %s>" % self.task_id
-
-    @property
-    def result(self):
-        """The tasks resulting value."""
-        if self.status == "DONE" or self.status == "FAILURE":
-            return self.backend.get_result(self.task_id)
-        return None
-
-    @property
-    def status(self):
-        """The current status of the task."""
-        return self.backend.get_status(self.task_id)
-
-
-class AsyncResult(BaseAsyncResult):
-    """Pending task result using the default backend.""" 
-
-    def __init__(self, task_id):
-        super(AsyncResult, self).__init__(task_id, backend=default_backend)
 
 
 
 
 def delay_task(task_name, *args, **kwargs):
 def delay_task(task_name, *args, **kwargs):
@@ -209,8 +147,6 @@ class Task(object):
         return delay_task(cls.name, *args, **kwargs)
         return delay_task(cls.name, *args, **kwargs)
 
 
 
 
-
-
 class TaskSet(object):
 class TaskSet(object):
     """A task containing several subtasks, making it possible
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks are completed.
     to track how many, or when all of the tasks are completed.
@@ -276,7 +212,7 @@ class TaskSet(object):
         publisher.close()
         publisher.close()
         return taskset_id, subtask_ids
         return taskset_id, subtask_ids
 
 
-    def xget(self):
+    def iterate(self):
         taskset_id, subtask_ids = self.run()
         taskset_id, subtask_ids = self.run()
         results = dict([(task_id, AsyncResult(task_id))
         results = dict([(task_id, AsyncResult(task_id))
                             for task_id in subtask_ids])
                             for task_id in subtask_ids])
@@ -288,7 +224,7 @@ class TaskSet(object):
                     raise pending_result.result
                     raise pending_result.result
 
 
     def join(self, timeout=None):
     def join(self, timeout=None):
-        time_start = time.time()
+        timeout_timer = TimeoutTimer(timeout)
         taskset_id, subtask_ids = self.run()
         taskset_id, subtask_ids = self.run()
         pending_results = map(AsyncResult, subtask_ids)
         pending_results = map(AsyncResult, subtask_ids)
         results = PositionQueue(length=len(subtask_ids))
         results = PositionQueue(length=len(subtask_ids))
@@ -301,8 +237,7 @@ class TaskSet(object):
                     raise pending_result.result
                     raise pending_result.result
             if results.is_full():
             if results.is_full():
                 return list(results)
                 return list(results)
-            if timeout and time.time() > time_start + timeout:
-                raise TimeOutError("The map operation timed out.")
+            timeout_timer.tick()
 
 
     @classmethod
     @classmethod
     def remote_execute(cls, func, args):
     def remote_execute(cls, func, args):
@@ -321,7 +256,6 @@ class TaskSet(object):
         return AsynchronousMapTask.delay(serfunc, args, timeout=timeout)
         return AsynchronousMapTask.delay(serfunc, args, timeout=timeout)
 
 
 
 
-
 def dmap(func, args, timeout=None):
 def dmap(func, args, timeout=None):
     """Distribute processing of the arguments and collect the results.
     """Distribute processing of the arguments and collect the results.
 
 
@@ -336,16 +270,16 @@ def dmap(func, args, timeout=None):
     """
     """
     return TaskSet.map(func, args, timeout=timeout)
     return TaskSet.map(func, args, timeout=timeout)
 
 
+
 class AsynchronousMapTask(Task):
 class AsynchronousMapTask(Task):
     name = "celery.map_async"
     name = "celery.map_async"
 
 
     def run(self, serfunc, args, **kwargs):
     def run(self, serfunc, args, **kwargs):
         timeout = kwargs.get("timeout")
         timeout = kwargs.get("timeout")
-        logger = self.get_logger(**kwargs)
-        logger.info("<<<<<<< ASYNCMAP: %s(%s)" % (serfunc, args))
         return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
         return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
 tasks.register(AsynchronousMapTask)
 tasks.register(AsynchronousMapTask)
 
 
+
 def dmap_async(func, args, timeout=None):
 def dmap_async(func, args, timeout=None):
     """Distribute processing of the arguments and collect the results
     """Distribute processing of the arguments and collect the results
     asynchronously. Returns a :class:`AsyncResult` object.
     asynchronously. Returns a :class:`AsyncResult` object.
@@ -419,6 +353,7 @@ tasks.register(ExecuteRemoteTask)
 def execute_remote(func, *args, **kwargs):
 def execute_remote(func, *args, **kwargs):
     return ExecuteRemoteTask.delay(pickle.dumps(func), args, kwargs)
     return ExecuteRemoteTask.delay(pickle.dumps(func), args, kwargs)
 
 
+
 class DeleteExpiredTaskMetaTask(PeriodicTask):
 class DeleteExpiredTaskMetaTask(PeriodicTask):
     """A periodic task that deletes expired task metadata every day.
     """A periodic task that deletes expired task metadata every day.
 
 

+ 13 - 0
celery/timer.py

@@ -22,3 +22,16 @@ class EventTimer(object):
             self.last_triggered = time.time()
             self.last_triggered = time.time()
 
 
 
 
+class TimeoutTimer(object):
+    """A timer that raises ``TimeoutError`` when the time has run out."""
+
+    def __init__(self, timeout):
+        self.timeout = timeout
+        self.time_start = time.time()
+
+    def tick(self):
+        if not self.timeout:
+            return
+        if time.time() > self.time_start + self.timeout:
+            raise TimeoutError("The operation timed out.")
+