Browse Source

Merge branch 'f0rk/f0rk/imap'

Ask Solem 13 years ago
parent
commit
e38562aff3
1 changed files with 37 additions and 12 deletions
  1. 37 12
      celery/concurrency/processes/pool.py

+ 37 - 12
celery/concurrency/processes/pool.py

@@ -592,7 +592,7 @@ class Pool(object):
         w.start()
         return w
 
-    def _join_exited_workers(self, shutdown=False, lost_worker_timeout=10.0):
+    def _join_exited_workers(self, shutdown=False):
         """Cleanup after any worker processes which have exited due to
         reaching their specified lifetime. Returns True if any workers were
         cleaned up.
@@ -600,11 +600,12 @@ class Pool(object):
         now = None
         # The worker may have published a result before being terminated,
         # but we have no way to accurately tell if it did.  So we wait for
-        # 10 seconds before we mark the job with WorkerLostError.
+        # _lost_worker_timeout seconds before we mark the job with
+        # WorkerLostError.
         for job in [job for job in self._cache.values()
                 if not job.ready() and job._worker_lost]:
             now = now or time.time()
-            if now - job._worker_lost > lost_worker_timeout:
+            if now - job._worker_lost > job._lost_worker_timeout:
                 exc_info = None
                 try:
                     raise WorkerLostError("Worker exited prematurely.")
@@ -730,39 +731,44 @@ class Pool(object):
         assert self._state == RUN
         return self.map_async(func, iterable, chunksize).get()
 
-    def imap(self, func, iterable, chunksize=1):
+    def imap(self, func, iterable, chunksize=1, lost_worker_timeout=10.0):
         '''
         Equivalent of `itertools.imap()` -- can be MUCH slower
         than `Pool.map()`
         '''
         assert self._state == RUN
         if chunksize == 1:
-            result = IMapIterator(self._cache)
+            result = IMapIterator(self._cache,
+                                  lost_worker_timeout=lost_worker_timeout)
             self._taskqueue.put((((result._job, i, func, (x,), {})
                          for i, x in enumerate(iterable)), result._set_length))
             return result
         else:
             assert chunksize > 1
             task_batches = Pool._get_tasks(func, iterable, chunksize)
-            result = IMapIterator(self._cache)
+            result = IMapIterator(self._cache,
+                                  lost_worker_timeout=lost_worker_timeout)
             self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                      for i, x in enumerate(task_batches)), result._set_length))
             return (item for chunk in result for item in chunk)
 
-    def imap_unordered(self, func, iterable, chunksize=1):
+    def imap_unordered(self, func, iterable, chunksize=1,
+                       lost_worker_timeout=10.0):
         '''
         Like `imap()` method but ordering of results is arbitrary
         '''
         assert self._state == RUN
         if chunksize == 1:
-            result = IMapUnorderedIterator(self._cache)
+            result = IMapUnorderedIterator(self._cache,
+                    lost_worker_timeout=lost_worker_timeout)
             self._taskqueue.put((((result._job, i, func, (x,), {})
                          for i, x in enumerate(iterable)), result._set_length))
             return result
         else:
             assert chunksize > 1
             task_batches = Pool._get_tasks(func, iterable, chunksize)
-            result = IMapUnorderedIterator(self._cache)
+            result = IMapUnorderedIterator(self._cache,
+                    lost_worker_timeout=lost_worker_timeout)
             self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                      for i, x in enumerate(task_batches)), result._set_length))
             return (item for chunk in result for item in chunk)
@@ -939,7 +945,7 @@ class ApplyResult(object):
 
     def __init__(self, cache, callback, accept_callback=None,
             timeout_callback=None, error_callback=None, soft_timeout=None,
-            timeout=None):
+            timeout=None, lost_worker_timeout=10.0):
         self._mutex = threading.Lock()
         self._cond = threading.Condition(threading.Lock())
         self._job = job_counter.next()
@@ -951,6 +957,7 @@ class ApplyResult(object):
         self._timeout_callback = timeout_callback
         self._timeout = timeout
         self._soft_timeout = soft_timeout
+        self._lost_worker_timeout = lost_worker_timeout
 
         self._accepted = False
         self._worker_pid = None
@@ -1094,15 +1101,19 @@ class MapResult(ApplyResult):
 
 
 class IMapIterator(object):
+    _worker_lost = None
 
-    def __init__(self, cache):
+    def __init__(self, cache, lost_worker_timeout=10.0):
         self._cond = threading.Condition(threading.Lock())
         self._job = job_counter.next()
         self._cache = cache
         self._items = collections.deque()
         self._index = 0
         self._length = None
+        self._ready = False
         self._unsorted = {}
+        self._worker_pids = []
+        self._lost_worker_timeout = lost_worker_timeout
         cache[self._job] = self
 
     def __iter__(self):
@@ -1115,12 +1126,14 @@ class IMapIterator(object):
                 item = self._items.popleft()
             except IndexError:
                 if self._index == self._length:
+                    self._ready = True
                     raise StopIteration
                 self._cond.wait(timeout)
                 try:
                     item = self._items.popleft()
                 except IndexError:
                     if self._index == self._length:
+                        self._ready = True
                         raise StopIteration
                     raise TimeoutError
         finally:
@@ -1129,7 +1142,7 @@ class IMapIterator(object):
         success, value = item
         if success:
             return value
-        raise value
+        raise Exception(value)
 
     __next__ = next                    # XXX
 
@@ -1148,6 +1161,7 @@ class IMapIterator(object):
                 self._unsorted[i] = obj
 
             if self._index == self._length:
+                self._ready = True
                 del self._cache[self._job]
         finally:
             self._cond.release()
@@ -1157,11 +1171,21 @@ class IMapIterator(object):
         try:
             self._length = length
             if self._index == self._length:
+                self._ready = True
                 self._cond.notify()
                 del self._cache[self._job]
         finally:
             self._cond.release()
 
+    def _ack(self, i, time_accepted, pid):
+        self._worker_pids.append(pid)
+
+    def ready(self):
+        return self._ready
+
+    def worker_pids(self):
+        return self._worker_pids
+
 #
 # Class whose instances are returned by `Pool.imap_unordered()`
 #
@@ -1176,6 +1200,7 @@ class IMapUnorderedIterator(IMapIterator):
             self._index += 1
             self._cond.notify()
             if self._index == self._length:
+                self._ready = True
                 del self._cache[self._job]
         finally:
             self._cond.release()