
AsynPool: Non-blocking read of result queue

Ask Solem 11 years ago
parent commit 5b6a096838
1 changed file with 104 additions and 47 deletions
      celery/concurrency/asynpool.py

+ 104 - 47
celery/concurrency/asynpool.py
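
This commit replaces the result handler's blocking reader.recv() with a
generator that reads one length-prefixed pickled message from the pipe in
non-blocking mode, yielding back to the event loop whenever os.read() has
no data. Below is a minimal standalone sketch of that pattern; it is an
illustration under assumed names (recv_message, fd, callback), not
celery's actual code:

    import errno
    import os
    import pickle
    import struct
    from io import BytesIO

    UNAVAIL = frozenset({errno.EAGAIN, errno.EINTR})

    def recv_message(fd, callback):
        # read the 4-byte big-endian size header
        header = b''
        while len(header) < 4:
            try:
                chunk = os.read(fd, 4 - len(header))
            except OSError as exc:
                if exc.errno not in UNAVAIL:
                    raise
                yield  # pipe empty: suspend until fd is readable again
            else:
                if not chunk:
                    raise EOFError('pipe closed while reading header')
                header += chunk
        size, = struct.unpack('>i', header)

        # read the pickled body
        buf = BytesIO()
        remaining = size
        while remaining > 0:
            try:
                chunk = os.read(fd, remaining)
            except OSError as exc:
                if exc.errno not in UNAVAIL:
                    raise
                yield
            else:
                if not chunk:
                    raise EOFError('pipe closed during message body')
                buf.write(chunk)
                remaining -= len(chunk)
        callback(pickle.loads(buf.getvalue()))

Driving it mirrors on_readable() in the diff: prime the generator with
next(it). StopIteration means the whole message arrived and callback
already ran; an ordinary return means the read is incomplete and the
generator itself is registered as the fd's read callback
(add_reader(fileno, it)), to be resumed on the next readability event.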

@@ -26,16 +26,14 @@ import socket
 import struct
 import time
 
-from collections import defaultdict, deque, namedtuple
+from collections import deque, namedtuple
+from io import BytesIO
 from pickle import HIGHEST_PROTOCOL
 from time import sleep
 from weakref import WeakValueDictionary, ref
 
 from amqp.utils import promise
-from billiard.pool import (
-    RUN, TERMINATE, ACK, NACK, EX_RECYCLE,
-    WorkersJoined, CoroStop,
-)
+from billiard.pool import RUN, TERMINATE, ACK, NACK, EX_RECYCLE, WorkersJoined
 from billiard import pool as _pool
 from billiard.einfo import ExceptionInfo
 from billiard.queues import _SimpleQueue
@@ -138,7 +136,7 @@ class Worker(_pool.Worker):
 
     def prepare_result(self, result):
         if not isinstance(result, ExceptionInfo):
-            return truncate(repr(result), 64)
+            return truncate(repr(result), 46)
 
 
 class ResultHandler(_pool.ResultHandler):
@@ -151,12 +149,58 @@ class ResultHandler(_pool.ResultHandler):
         # add our custom message handler
         self.state_handlers[WORKER_UP] = self.on_process_alive
 
+    def _recv_message(self, add_reader, fd, callback,
+                      read=os.read, unpack=struct.unpack,
+                      loads=_pickle.loads, BytesIO=BytesIO):
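+        # Read one length-prefixed pickled message from fd without
+        # blocking, yielding whenever the pipe runs dry so the event
+        # loop can resume this generator when fd is readable again.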
+        buf = BytesIO()
+        # header
+        remaining = 4
+        bsize = b''
+        while remaining > 0:
+            try:
+                chunk = read(fd, remaining)
+            except OSError as exc:
+                if get_errno(exc) not in UNAVAIL:
+                    raise
+                yield
+            else:
+                n = len(chunk)
+                if n == 0:
+                    if remaining == 4:
+                        raise EOFError()
+                    else:
+                        raise OSError('Got end of file during message')
+                bsize += chunk
+                remaining -= n
+
+        remaining, = size, = unpack('>i', bsize)
+        while remaining > 0:
+            try:
+                chunk = read(fd, remaining)
+            except OSError as exc:
+                if get_errno(exc) not in UNAVAIL:
+                    raise
+                yield
+            else:
+                n = len(chunk)
+                if n == 0:
+                    if remaining == size:
+                        raise EOFError()
+                    else:
+                        raise IOError('Got end of file during message')
+                buf.write(chunk)
+                remaining -= n
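+        # whole message received: hand the fd back to the regular
+        # event handler, then decode and dispatch the payload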
+        add_reader(fd, self.handle_event, fd)
+        message = loads(buf.getvalue())
+        if message:
+            callback(message)
+
     def _make_process_result(self, hub):
         """Coroutine that reads messages from the pool processes
         and calls the appropriate handler."""
         fileno_to_outq = self.fileno_to_outq
         on_state_change = self.on_state_change
+        add_reader = hub.add_reader
         hub_remove = hub.remove
+        recv_message = self._recv_message
 
         def on_readable(fileno):
             try:
@@ -166,23 +210,17 @@ class ResultHandler(_pool.ResultHandler):
                 # process gone
                 return
             reader = proc.outq._reader
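+            # non-blocking reads let a partial message suspend the
+            # generator instead of stalling the whole event loop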
+            reader.setblocking(0)
 
+            it = recv_message(add_reader, fileno, on_state_change)
             try:
-                message = reader.recv()
-            except (IOError, EOFError) as exc:
-                debug('result handler got %r -- exiting', exc)
+                next(it)
+            except StopIteration:
+                pass
+            except (IOError, OSError, EOFError):
                 hub_remove(fileno)
-                return
-
-            if self._state:
-                assert self._state == TERMINATE
-                debug('result handler found thread._state==TERMINATE')
-                return
-
-            if message is None:
-                debug('result handler got sentinel -- exiting')
-                return
-            on_state_change(message)
+            else:
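+                # incomplete message: resume the generator when the
+                # fd becomes readable again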
+                add_reader(fileno, it)
         return on_readable
 
     def register_with_event_loop(self, hub):
@@ -217,6 +255,7 @@ class ResultHandler(_pool.ResultHandler):
                     break
 
                 reader = proc.outq._reader
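+                # at shutdown a blocking read is acceptable; the
+                # finally clause below restores non-blocking mode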
+                reader.setblocking(1)
                 try:
                     if reader.poll(0):
                         task = reader.recv()
@@ -229,6 +268,9 @@ class ResultHandler(_pool.ResultHandler):
                 else:
                     if task:
                         on_state_change(task)
+                finally:
+                    reader.setblocking(0)
+
                 try:
                     join_exited_workers(shutdown=True)
                 except WorkersJoined:
@@ -281,6 +323,8 @@ class AsynPool(_pool.Pool):
         # Holds jobs waiting to be written to child processes.
         self.outbound_buffer = deque()
 
+        self.write_stats = Counter()
+
         super(AsynPool, self).__init__(processes, *args, **kwargs)
 
         for proc in self._pool:
@@ -490,15 +534,17 @@ class AsynPool(_pool.Pool):
         mark_worker_as_busy = busy_workers.add
         write_generator_done = active_writers.discard
         get_job = self._cache.__getitem__
+        write_stats = self.write_stats
+        is_fair_strategy = self.sched_strategy == SCHED_STRATEGY_FAIR
+
+        precalc = {ACK: self._create_payload(ACK, (0, )),
+                   NACK: self._create_payload(NACK, (0, ))}
 
         def _put_back(job):
             # puts back at the end of the queue
             if job not in outbound:  # XXX slow, should find another way
                 outbound.appendleft(job)
         self._put_back = _put_back
-        precalc = {ACK: self._create_payload(ACK, (0, )),
-                   NACK: self._create_payload(NACK, (0, ))}
-        is_fair_strategy = self.sched_strategy == SCHED_STRATEGY_FAIR
 
         def on_poll_start():
             # called for every event loop iteration, and if there
@@ -606,12 +652,12 @@ class AsynPool(_pool.Pool):
             put_message(job)
         self._quick_put = send_job
 
-        write_stats = self.write_stats = Counter()
-
-        def on_not_recovering(proc):
-            # XXX Theoretically a possibility, but not seen in practice yet.
-            raise Exception(
-                'Process writable but cannot write. Contact support!')
+        def on_not_recovering(proc, fd, job):
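+            # writes kept failing although the fd polled writable:
+            # give up on this worker and reschedule the job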
+            error('Process inqueue damaged: %r %r', proc, proc.exitcode)
+            if proc.exitcode is None:
+                proc.terminate()
+            hub.remove(fd)
+            self._put_back(job)
 
         def _write_job(proc, fd, job):
             # writes job to the worker process.
@@ -636,10 +682,11 @@ class AsynPool(_pool.Pool):
                         # suspend until more data
                         errors += 1
                         if errors > 100:
-                            on_not_recovering(proc)
+                            on_not_recovering(proc, fd, job)
                             raise StopIteration()
                         yield
-                    errors = 0
+                    else:
+                        errors = 0
 
                 # write body
                 while Bw < body_size:
@@ -651,10 +698,11 @@ class AsynPool(_pool.Pool):
                         # suspend until more data
                         errors += 1
                         if errors > 100:
-                            on_not_recovering(proc)
+                            on_not_recovering(proc, fd, job)
                             raise StopIteration()
                         yield
-                    errors = 0
+                    else:
+                        errors = 0
             finally:
                 write_stats[proc.index] += 1
                 # message written, so this fd is now available
@@ -734,11 +782,11 @@ class AsynPool(_pool.Pool):
             if self._state == RUN:
                 # flush outgoing buffers
                 intervals = fxrange(0.01, 0.1, 0.01, repeatlast=True)
-                owned_by = defaultdict(set)
+                owned_by = {}
                 for job in values(self._cache):
                     writer = _get_job_writer(job)
                     if writer is not None:
-                        owned_by[writer].add(job)
+                        owned_by[writer] = job
 
                 while self._active_writers:
                     writers = list(self._active_writers)
@@ -748,17 +796,22 @@ class AsynPool(_pool.Pool):
                             # has not started writing the job so can
                             # discard the task, but we must also remove
                             # it from the Pool._cache.
-                            for job in owned_by[gen]:
+                            try:
+                                job = owned_by[gen]
+                            except KeyError:
+                                pass
+                            else:
                                 # removes from Pool._cache
                                 job.discard()
                             self._active_writers.discard(gen)
                         else:
-                            terminated = set()
-                            for job in owned_by[gen]:
-                                if job._write_to.exitcode is not None:
-                                    terminated.add(job)
-                            if len(terminated) < len(owned_by[gen]):
-                                self._flush_writer(gen)
+                            try:
+                                job = owned_by[gen]
+                            except KeyError:
+                                pass
+                            else:
+                                if job._write_to.exitcode is None:
+                                    self._flush_writer(gen)
                     # workers may have exited in the meantime.
                     self.maintain_pool()
                     sleep(next(intervals))  # don't busyloop
@@ -770,11 +823,7 @@ class AsynPool(_pool.Pool):
 
     def _flush_writer(self, writer):
         try:
-            while 1:
-                try:
-                    next(writer)
-                except StopIteration:
-                    break
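+            # exhaust the writer generator so a partially written
+            # job is flushed to completion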
+            list(writer)
         except (OSError, IOError) as exc:
             if get_errno(exc) != errno.EBADF:
                 raise
@@ -807,6 +856,7 @@ class AsynPool(_pool.Pool):
         returned as a tuple."""
         inq, outq, synq = _SimpleQueue(), _SimpleQueue(), None
         inq._writer.setblocking(0)
+        outq._reader.setblocking(0)
         if self.synack:
             synq = _SimpleQueue()
             synq._writer.setblocking(0)
@@ -859,6 +909,10 @@ class AsynPool(_pool.Pool):
             'avg': per(total / len(self.write_stats) if total else 0, total),
             'all': ', '.join(per(v, total) for v in vals),
             'raw': ', '.join(map(str, vals)),
+            'inqueues': {
+                'total': len(self._all_inqueues),
+                'active': len(self._active_writes),
+            }
         }
 
     def _process_cleanup_queues(self, proc):
@@ -923,6 +977,7 @@ class AsynPool(_pool.Pool):
         resq = proc.outq._reader
         on_state_change = self._result_handler.on_state_change
         while not resq.closed and resq.poll(0) and self._state != TERMINATE:
+            resq.setblocking(1)
             try:
                 task = resq.recv()
             except (IOError, EOFError) as exc:
@@ -934,6 +989,8 @@ class AsynPool(_pool.Pool):
                     on_state_change(task)
                 else:
                     debug('got sentinel while flushing process %r', proc)
+            finally:
+                resq.setblocking(0)
 
     def on_partial_read(self, job, proc):
         """Called when a job was only partially written to a child process